diff options
353 files changed, 11927 insertions, 6188 deletions
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# RBE requires a strong hash function, such as SHA256. +# Ensure a strong hash function. startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 # Build with C++17. @@ -20,27 +20,3 @@ build --cxxopt=-std=c++17 # Display the current git revision in the info block. build --stamp --workspace_status_command tools/workspace_status.sh - -# Enable remote execution so actions are performed on the remote systems. -build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:remote --bes_backend=buildeventservice.googleapis.com -build:remote --bes_results_url="https://source.cloud.google.com/results/invocations" -build:remote --bes_timeout=600s -build:remote --project_id=gvisor-rbe -build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance - -# Enable authentication. This will pick up application default credentials by -# default. You can use --google_credentials=some_file.json to use a service -# account credential instead. -build:remote --google_default_credentials=true -build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" - -# Add a custom platform and toolchain that builds in a privileged docker -# container, which is required by our syscall tests. -build:remote --host_platform=//tools/bazeldefs:rbe_ubuntu1604 -build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default -build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604 -build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604 -build:remote --crosstool_top=@rbe_default//cc:toolchain -build:remote --jobs=100 -build:remote --remote_timeout=3600 diff --git a/.buildkite/hooks/post-command b/.buildkite/hooks/post-command new file mode 100644 index 000000000..b0396bec7 --- /dev/null +++ b/.buildkite/hooks/post-command @@ -0,0 +1,74 @@ +# Upload test logs on failure, if there are any. 
+if test "${BUILDKITE_COMMAND_EXIT_STATUS}" -ne "0"; then + # Generate a metafile that ends with .output, and contains all the + # test failures that have been uploaded. These will all be sorted and + # aggregated by a failure stage in the build pipeline. + declare output=$(mktemp "${BUILDKITE_JOB_ID}".XXXXXX.output) + make -s testlogs 2>/dev/null | grep // | sort | uniq | ( + declare log_count=0 + while read target log; do + if test -z "${target}"; then + continue + fi + + # N.B. If *all* tests fail due to some common cause, then we will + # end up spending way too much time uploading logs. Instead, we just + # upload the first 10 and stop. That is hopefully enough to debug. + # + # We include this test in the metadata, but note that we cannot + # upload the actual test logs. The user should rerun locally. + log_count=$((${log_count}+1)) + if test "${log_count}" -ge 10; then + echo " * ${target} (no upload)" | tee -a "${output}" + else + buildkite-agent artifact upload "${log}" + echo " * [${target}](artifact://${log#/})" | tee -a "${output}" + fi + done + ) + + # Upload if we had outputs. + if test -s "${output}"; then + buildkite-agent artifact upload "${output}" + fi + rm -rf "${output}" + + # Attempt to clear the cache and shut down. + make clean || echo "make clean failed with code $?" + make bazel-shutdown || echo "make bazel-shutdown failed with code $?" +fi + +# Upload all profiles, and include in an annotation. +if test -d /tmp/profile; then + # Same as above. + declare profile_output=$(mktemp "${BUILDKITE_JOB_ID}".XXXXXX.profile_output) + for file in $(find /tmp/profile -name \*.pprof -print 2>/dev/null | sort); do + # Generate a link to speedscope, with a URL-encoded link to the BuildKite + # artifact location. Note that we do a fixed URL encode below, since + # the link can be uniquely determined. If the storage location changes, + # this schema may break and these links may stop working. 
The artifacts + # uploaded however, will still work just fine. + profile_name="${file#/tmp/profile/}" + public_url="https://storage.googleapis.com/gvisor-buildkite/${BUILDKITE_BUILD_ID}/${BUILDKITE_JOB_ID}/${file#/}" + encoded_url=$(jq -rn --arg x "${public_url}" '$x|@uri') + encoded_title=$(jq -rn --arg x "${profile_name}" '$x|@uri') + profile_url="https://speedscope.app/#profileURL=${encoded_url}&title=${encoded_title}" + buildkite-agent artifact upload "${file}" + echo " * [${profile_name}](${profile_url}) ([pprof](artifact://${file#/}))" | tee -a "${profile_output}" + done + + # Upload if we had outputs. + if test -s "${profile_output}"; then + buildkite-agent artifact upload "${profile_output}" + fi + rm -rf "${profile_output}" + + # Remove stale profiles, which may be owned by root. + sudo rm -rf /tmp/profile +fi + +# Kill any running containers (clear state). +CONTAINERS="$(docker ps -q)" +if ! test -z "${CONTAINERS}"; then + docker container kill ${CONTAINERS} 2>/dev/null || true +fi diff --git a/.buildkite/hooks/pre-command b/.buildkite/hooks/pre-command new file mode 100644 index 000000000..4f41fe021 --- /dev/null +++ b/.buildkite/hooks/pre-command @@ -0,0 +1,30 @@ +# Install packages we need. Docker must be installed and configured, +# as should Go itself. We just install some extra bits and pieces. +function install_pkgs() { + while true; do + if sudo apt-get update && sudo apt-get install -y "$@"; then + break + fi + done +} +install_pkgs graphviz jq curl binutils gnupg gnupg-agent linux-libc-dev \ + apt-transport-https ca-certificates software-properties-common + +# Setup for parallelization with PARTITION and TOTAL_PARTITIONS. +export PARTITION=${BUILDKITE_PARALLEL_JOB:-0} +PARTITION=$((${PARTITION}+1)) # 1-indexed, but PARALLEL_JOB is 0-indexed. +export TOTAL_PARTITIONS=${BUILDKITE_PARALLEL_JOB_COUNT:-1} + +# Ensure Docker has experimental enabled. 
+EXPERIMENTAL=$(sudo docker version --format='{{.Server.Experimental}}') +if test "${EXPERIMENTAL}" != "true"; then + make sudo TARGETS=//runsc:runsc ARGS="install --experimental=true" + sudo systemctl restart docker +fi + +# Helper for benchmarks, based on the branch. +if test "${BUILDKITE_BRANCH}" = "master"; then + export BENCHMARKS_OFFICIAL=true +else + export BENCHMARKS_OFFICIAL=false +fi
\ No newline at end of file diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index 337f45870..ba054319c 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -1,5 +1,6 @@ _templates: common: &common + timeout_in_minutes: 30 retry: automatic: - exit_status: -1 @@ -9,7 +10,8 @@ _templates: steps: # Run basic smoke tests before preceding to other tests. - - label: ":fire: Smoke tests" + - <<: *common + label: ":fire: Smoke tests" command: make smoke-tests - wait @@ -17,12 +19,15 @@ steps: - <<: *common label: ":golang: Go branch" commands: - - rm -rf bazel-bin/gopath - - make build TARGETS="//:gopath" - - tools/go_branch.sh + - make go - git checkout go && git clean -f - go build ./... + # Release workflow. + - <<: *common + label: ":ship: Release tests" + commands: make release + # Basic unit tests. - <<: *common label: ":test_tube: Unit tests" @@ -79,25 +84,66 @@ steps: label: ":hammer: Packetimpact tests" command: make packetimpact-tests - # Start heavy runtime tests. - - wait + # Runtime tests. - <<: *common label: ":php: PHP runtime tests" - command: make php7.3.6-runtime-tests + command: make php7.3.6-runtime-tests_vfs2 parallelism: 10 - <<: *common label: ":java: Java runtime tests" - command: make java11-runtime-tests + command: make java11-runtime-tests_vfs2 parallelism: 40 - <<: *common label: ":golang: Go runtime tests" - command: make go1.12-runtime-tests + command: make go1.12-runtime-tests_vfs2 parallelism: 10 - <<: *common label: ":node: NodeJS runtime tests" - command: make nodejs12.4.0-runtime-tests + command: make nodejs12.4.0-runtime-tests_vfs2 parallelism: 10 - <<: *common label: ":python: Python runtime tests" + command: make python3.7.3-runtime-tests_vfs2 + parallelism: 10 + + # Runtime tests (VFS1). 
+ - <<: *common + label: ":php: PHP runtime tests (VFS1)" + command: make php7.3.6-runtime-tests + parallelism: 10 + if: build.message =~ /VFS1/ || build.branch == "master" + - <<: *common + label: ":java: Java runtime tests (VFS1)" + command: make java11-runtime-tests + parallelism: 40 + if: build.message =~ /VFS1/ || build.branch == "master" + - <<: *common + label: ":golang: Go runtime tests (VFS1)" + command: make go1.12-runtime-tests + parallelism: 10 + if: build.message =~ /VFS1/ || build.branch == "master" + - <<: *common + label: ":node: NodeJS runtime tests (VFS1)" + command: make nodejs12.4.0-runtime-tests + parallelism: 10 + if: build.message =~ /VFS1/ || build.branch == "master" + - <<: *common + label: ":python: Python runtime tests (VFS1)" command: make python3.7.3-runtime-tests parallelism: 10 + if: build.message =~ /VFS1/ || build.branch == "master" + + # The final step here will aggregate data uploaded by all other steps into an + # annotation that will appear at the top of the build, with useful information. + # + # See .buildkite/summarize.sh and .buildkite/hooks/post-command for more. + - wait + - <<: *common + label: ":yawning_face: Wait" + command: "true" + key: "wait" + - <<: *common + label: ":thisisfine: Summarize" + command: .buildkite/summarize.sh + allow_dependency_failure: true + depends_on: "wait" diff --git a/.buildkite/summarize.sh b/.buildkite/summarize.sh new file mode 100755 index 000000000..ddf8c9ad4 --- /dev/null +++ b/.buildkite/summarize.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeou pipefail + +# This script collects metadata fragments produced by individual test shards in +# .buildkite/hooks/post-command, and aggregates these into a single annotation +# that is posted to the build. In the future, this will include coverage. + +# Start the summary. +declare summary +declare status +summary=$(mktemp --tmpdir summary.XXXXXX) +status="info" + +# Download all outputs. +declare outputs +outputs=$(mktemp -d --tmpdir outputs.XXXXXX) +if buildkite-agent artifact download '**/*.output' "${outputs}"; then + status="error" + echo "## Failures" >> "${summary}" + find "${outputs}" -type f -print | xargs -r -n 1 cat | sort >> "${summary}" +fi +rm -rf "${outputs}" + +# Attempt to find profiles, if there are any. +declare profiles +profiles=$(mktemp -d --tmpdir profiles.XXXXXX) +if buildkite-agent artifact download '**/*.profile_output' "${profiles}"; then + echo "## Profiles" >> "${summary}" + find "${profiles}" -type f -print | xargs -r -n 1 cat | sort >> "${summary}" +fi +rm -rf "${profiles}" + +# Upload the final annotation. +if [[ -s "${summary}" ]]; then + cat "${summary}" | buildkite-agent annotate --style "${status}" +fi +rm -rf "${summary}" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md deleted file mode 100644 index 264b4e9fa..000000000 --- a/.github/pull_request_template.md +++ /dev/null @@ -1,5 +0,0 @@ -* [ ] Have you followed the guidelines in [CONTRIBUTING.md](../blob/master/CONTRIBUTING.md)? -* [ ] Have you formatted and linted your code? -* [ ] Have you added relevant tests? 
-* [ ] Have you added appropriate Fixes & Updates references? -* [ ] If yes, please erase all these lines! diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e28e46352..270aaf034 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,13 +1,15 @@ +# This workflow builds the source code, extracts nogo annotations and +# posts them to GitHub, if applicable. This leverages the fact that the +# workflow token has appropriate permissions to do so, and attempts to +# leverage the GitHub workflow caches. name: "Build" -on: +"on": push: branches: - master - - feature/** pull_request: branches: - - master - - feature/** + - "**" jobs: default: @@ -22,7 +24,7 @@ jobs: ${{ runner.os }}-bazel- - run: make - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..." - - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ nogo" + - run: make run TARGETS="//tools/github" ARGS="-path=bazel-bin/ -path=bazel-out/ nogo" env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3a6a592d1..e62991691 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -1,12 +1,12 @@ +# This workflow generates the Go branch. Note that this does not test the Go +# branch, as this is rolled into the main continuous integration pipeline. This +# workflow simply generates and pushes the branch, as long as appropriate +# permissions are available. 
name: "Go" -on: +"on": push: branches: - master - pull_request: - branches: - - master - - feature/** jobs: generate: @@ -19,20 +19,13 @@ jobs: else echo ::set-output name=has_token::false fi - - run: | - jq -nc '{"state": "pending", "context": "go tests"}' | \ - curl -sL -X POST -d @- \ - -H "Content-Type: application/json" \ - -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ - "${{ github.event.pull_request.statuses_url }}" - if: github.event_name == 'pull_request' - uses: actions/checkout@v2 - if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true' + if: steps.setup.outputs.has_token == 'true' with: fetch-depth: 0 token: '${{ secrets.GO_TOKEN }}' - uses: actions/checkout@v2 - if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true' + if: steps.setup.outputs.has_token != 'true' with: fetch-depth: 0 - uses: actions/setup-go@v2 @@ -50,32 +43,7 @@ jobs: key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }} restore-keys: | ${{ runner.os }}-bazel- - # Create gopath to merge the changes. The first execution will create - # symlinks to the cache, e.g. bazel-bin. Once the cache is setup, delete - # old gopath files that may exist from previous runs (and could contain - # files that are now deleted). Then run gopath again for good. + - run: make go - run: | - make build TARGETS="//:gopath" - rm -rf bazel-bin/gopath - make build TARGETS="//:gopath" - - run: tools/go_branch.sh - - run: git checkout go && git clean -f - - run: go build ./... 
- - if: github.event_name == 'push' - run: | git remote add upstream "https://github.com/${{ github.repository }}" git push upstream go:go - - if: ${{ success() && github.event_name == 'pull_request' }} - run: | - jq -nc '{"state": "success", "context": "go tests"}' | \ - curl -sL -X POST -d @- \ - -H "Content-Type: application/json" \ - -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ - "${{ github.event.pull_request.statuses_url }}" - - if: ${{ failure() && github.event_name == 'pull_request' }} - run: | - jq -nc '{"state": "failure", "context": "go tests"}' | \ - curl -sL -X POST -d @- \ - -H "Content-Type: application/json" \ - -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ - "${{ github.event.pull_request.statuses_url }}" diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml index c53185620..3bd883035 100644 --- a/.github/workflows/issue_reviver.yml +++ b/.github/workflows/issue_reviver.yml @@ -1,5 +1,7 @@ +# This workflow revives issues that are still referenced in the code, and may +# have been accidentally closed or marked stale. name: "Issue reviver" -on: +"on": schedule: - cron: '0 0 * * *' diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index c09f7eb36..3a19065e1 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -1,5 +1,6 @@ +# Labeler labels incoming pull requests. name: "Labeler" -on: +"on": - pull_request jobs: diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 0b31fecf5..3a4aa22e2 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,5 +1,7 @@ +# The stale workflow closes stale issues and pull requests, unless specific +# tags have been applied in order to keep them open. name: "Close stale issues" -on: +"on": schedule: - cron: "0 0 * * *" diff --git a/.gitignore b/.gitignore index 95fe857dd..a2a3fd508 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ # Generated bazel symlinks. 
/bazel-* # Generated build event file. -/.build_events.json
\ No newline at end of file +/.build_events.json +# Generated repository. +/repo +/repo.key
\ No newline at end of file @@ -1,5 +1,6 @@ load("//tools:defs.bzl", "build_test", "gazelle", "go_path") load("//tools/nogo:defs.bzl", "nogo_config") +load("//tools/yamltest:defs.bzl", "yaml_test") load("//website:defs.bzl", "doc") package(licenses = ["notice"]) @@ -50,6 +51,24 @@ doc( weight = "99", ) +yaml_test( + name = "nogo_config_test", + srcs = glob(["nogo*.yaml"]), + schema = "//tools/nogo:config-schema.json", +) + +yaml_test( + name = "github_workflows_test", + srcs = glob([".github/workflows/*.yml"]), + schema = "@github_workflow_schema//file", +) + +yaml_test( + name = "buildkite_pipelines_test", + srcs = glob([".buildkite/*.yaml"]), + schema = "@buildkite_pipeline_schema//file", +) + # The sandbox filegroup is used for sandbox-internal dependencies. package_group( name = "sandbox", @@ -67,12 +86,15 @@ build_test( "//test/benchmarks/base:startup_test", "//test/benchmarks/base:size_test", "//test/benchmarks/base:sysbench_test", - "//test/benchmarks/database:database_test", + "//test/benchmarks/database:redis_test", "//test/benchmarks/fs:bazel_test", "//test/benchmarks/fs:fio_test", - "//test/benchmarks/media:media_test", - "//test/benchmarks/ml:ml_test", - "//test/benchmarks/network:network_test", + "//test/benchmarks/media:ffmpeg_test", + "//test/benchmarks/ml:tensorflow_test", + "//test/benchmarks/network:httpd_test", + "//test/benchmarks/network:nginx_test", + "//test/benchmarks/network:node_test", + "//test/benchmarks/network:ruby_test", ], ) @@ -102,7 +124,9 @@ go_path( "//pkg/sentry/kernel/memevent", "//pkg/tcpip/adapters/gonet", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/muxed", + "//pkg/tcpip/link/pipe", "//pkg/tcpip/link/sharedmem", "//pkg/tcpip/link/sharedmem/pipe", "//pkg/tcpip/link/sharedmem/queue", @@ -14,27 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Helpful pretty-printer. 
-ifeq (0,$(MAKELEVEL)) -OPENLAST := || (rc=$$?; echo '^^^ +++' >&2; exit $$rc) -else -OPENLAST := -endif -CMDLINE := $(shell cut -d '' -f2- /proc/$$PPID/cmdline | sed 's|\x00| |g') -submake = echo '--- make $1' >&2 && \ - $(MAKE) -s $1 && \ - echo '--- make $(CMDLINE) (resume)' >&2 \ - $(OPENLAST) - -# Described below. -OPTIONS := -STARTUP_OPTIONS := -TARGETS := //runsc -ARGS := - default: runsc .PHONY: default +# Header for debugging (used by other macros). +header = echo --- $(1) >&2 + +# Make hacks. +EMPTY := +SPACE := $(EMPTY) $(EMPTY) +SHELL = /bin/bash + ## usage: make <target> ## or ## make <build|test|copy|run|sudo> STARTUP_OPTIONS="..." OPTIONS="..." TARGETS="..." ARGS="..." @@ -46,7 +36,6 @@ default: runsc ## requirements. ## ## There are common arguments that may be passed to targets. These are: -## STARTUP_OPTIONS - Bazel startup options. ## OPTIONS - Build or test options. ## TARGETS - The bazel targets. ## ARGS - Arguments for run or sudo. @@ -57,7 +46,7 @@ default: runsc ## make build OPTIONS="" TARGETS="//runsc"' ## help: ## Shows all targets and help from the Makefile (this message). - @grep --no-filename -E '^([a-z.A-Z_-]+:.*?|)##' $(MAKEFILE_LIST) | \ + @grep --no-filename -E '^([a-z.A-Z_%-]+:.*?|)##' $(MAKEFILE_LIST) | \ awk 'BEGIN {FS = "(:.*?|)## ?"}; { \ if (length($$1) > 0) { \ printf " \033[36m%-20s\033[0m %s\n", $$1, $$2; \ @@ -65,17 +54,34 @@ help: ## Shows all targets and help from the Makefile (this message). printf "%s\n", $$2; \ } \ }' + build: ## Builds the given $(TARGETS) with the given $(OPTIONS). E.g. make build TARGETS=runsc -test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test -copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp -run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version -sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". 
E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v -.PHONY: help build test copy run sudo + @$(call build,$(OPTIONS) $(TARGETS)) +.PHONY: build + +test: ## Tests the given $(TARGETS) with the given $(OPTIONS). E.g. make test TARGETS=pkg/buffer:buffer_test + @$(call test,$(OPTIONS) $(TARGETS)) +.PHONY: test + +copy: ## Copies the given $(TARGETS) to the given $(DESTINATION). E.g. make copy TARGETS=runsc DESTINATION=/tmp + @$(call copy,$(TARGETS),$(DESTINATION)) +.PHONY: copy + +run: ## Runs the given $(TARGETS), built with $(OPTIONS), using $(ARGS). E.g. make run TARGETS=runsc ARGS=-version + @$(call run,$(TARGETS),$(ARGS)) +.PHONY: run + +sudo: ## Runs the given $(TARGETS) as per run, but using "sudo -E". E.g. make sudo TARGETS=test/root:root_test ARGS=-test.v + @$(call sudo,$(TARGETS),$(ARGS)) +.PHONY: sudo + +# Load image helpers. +include tools/images.mk # Load all bazel wrappers. # # This file should define the basic "build", "test", "run" and "sudo" rules, in -# addition to the $(BRANCH_NAME) variable. +# addition to the $(BRANCH_NAME) and $(BUILD_ROOTS) variables. ifneq (,$(wildcard tools/google.mk)) include tools/google.mk else @@ -83,32 +89,76 @@ include tools/bazel.mk endif ## -## Docker image targets. -## -## Images used by the tests must also be built and available locally. -## The canonical test targets defined below will automatically load -## relevant images. These can be loaded or built manually via these -## targets. +## Development helpers and tooling. ## -## (*) Note that you may provide an ARCH parameter in order to build -## and load images from an alternate archiecture (using qemu). When -## bazel is run as a server, this has the effect of running an full -## cross-architecture chain, and can produce cross-compiled binaries. +## These targets facilitate local development by automatically +## installing and configuring a runtime. 
Several variables may +## be used here to tweak the installation: +## RUNTIME - The name of the installed runtime (default: branch). +## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME). +## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc). +## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs). +## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%). ## -define images -$(1)-%: ## Image tool: $(1) a given image (also may use 'all-images'). - @$(call submake,-C images $$@) -endef -rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'. -$(eval $(call images,rebuild)) -push-...: ## Push the given image. Also may use 'push-all-images'. -$(eval $(call images,push)) -pull-...: ## Pull the given image. Also may use 'pull-all-images'. -$(eval $(call images,pull)) -load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'. -$(eval $(call images,load)) -list-images: ## List all available images. - @$(call submake, -C images $$@) +ifeq (,$(BRANCH_NAME)) +RUNTIME := runsc +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) +else +RUNTIME := $(BRANCH_NAME) +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) +endif +RUNTIME_BIN := $(RUNTIME_DIR)/runsc +RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs +RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% + +$(RUNTIME_BIN): # See below. + @mkdir -p "$(RUNTIME_DIR)" + @$(call copy,//runsc,$(RUNTIME_BIN)) +.PHONY: $(RUNTIME_BIN) # Real file, but force rebuild. + +# Configure helpers for below. 
+configure_noreload = \ + $(call header,CONFIGURE $(1) → $(RUNTIME_BIN) $(2)); \ + sudo $(RUNTIME_BIN) install --experimental=true --runtime="$(1)" -- --debug-log "$(RUNTIME_LOGS)" $(2) && \ + sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)" +reload_docker = \ + sudo systemctl reload docker && \ + if test -f /etc/docker/daemon.json; then \ + sudo chmod 0755 /etc/docker && \ + sudo chmod 0644 /etc/docker/daemon.json; \ + fi +configure = $(call configure_noreload,$(1),$(2)) && $(reload_docker) + +# Helpers for above. Requires $(RUNTIME_BIN) dependency. +install_runtime = $(call configure,$(1),$(2) --TESTONLY-test-name-env=RUNSC_TEST_NAME) +# Don't use cached results, otherwise multiple runs using different runtimes +# may be skipped, if all other inputs are the same. +test_runtime = $(call test,--test_arg=--runtime=$(1) --nocache_test_results $(PARTITIONS) $(2)) + +refresh: $(RUNTIME_BIN) ## Updates the runtime binary. +.PHONY: refresh + +dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo. + @$(call configure_noreload,$(RUNTIME),--net-raw) + @$(call configure_noreload,$(RUNTIME)-d,--net-raw --debug --strace --log-packets) + @$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile) + @$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2) + @$(call reload_docker) +.PHONY: dev + +nogo: ## Surfaces all nogo findings. + @$(call build,--build_tag_filters nogo //...) + @$(call run,//tools/github $(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo) +.PHONY: nogo + +go: ## Builds the Go branch. + @$(call clean) + @$(call build,//:gopath) + @tools/go_branch.sh + +gazelle: ## Runs gazelle to update WORKSPACE. + @$(call run,//:gazelle update-repos -from_file=go.mod -prune) +.PHONY: gazelle ## ## Canonical build and test targets. 
@@ -126,23 +176,23 @@ TOTAL_PARTITIONS ?= 1 PARTITIONS := --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS) runsc: ## Builds the runsc binary. - @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc") + @$(call build,-c opt //runsc) .PHONY: runsc debian: ## Builds the debian packages. - @$(call submake,build OPTIONS="-c opt" TARGETS="//debian:debian") + @$(call build,-c opt //debian:debian) .PHONY: debian smoke-tests: ## Runs a simple smoke test after build runsc. - @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true") + @$(call run,//runsc,--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true) .PHONY: smoke-tests fuse-tests: - @$(call submake,test OPTIONS="--test_tag_filters fuse $(PARTITIONS)" TARGETS="test/fuse/...") + @$(call test,--test_tag_filters=fuse $(PARTITIONS) test/fuse/...) .PHONY: fuse-tests unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc. - @$(call submake,test TARGETS="pkg/... runsc/... tools/...") + @$(call test,pkg/... runsc/... tools/...) .PHONY: unit-tests tests: ## Runs all unit tests and syscall tests. @@ -158,101 +208,92 @@ network-tests: ## Run all networking integration tests. network-tests: iptables-tests packetdrill-tests packetimpact-tests .PHONY: network-tests -# Standard integration targets. -INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test - syscall-%-tests: - @$(call submake,test OPTIONS="--test_tag_filters runsc_$* $(PARTITIONS)" TARGETS="test/syscalls/...") + @$(call test,--test_tag_filters=runsc_$* $(PARTITIONS) test/syscalls/...) syscall-native-tests: - @$(call submake,test OPTIONS="--test_tag_filters native $(PARTITIONS)" TARGETS="test/syscalls/...") + @$(call test,--test_tag_filters=native $(PARTITIONS) test/syscalls/...) .PHONY: syscall-native-tests syscall-tests: ## Run all system call tests. 
- @$(call submake,test OPTIONS="$(PARTITIONS)" TARGETS="test/syscalls/...") + @$(call test,$(PARTITIONS) test/syscalls/...) -%-runtime-tests: load-runtimes_% - @$(call submake,install-runtime) - @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") +%-runtime-tests: load-runtimes_% $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),) # Ensure flags are cleared. + @$(call test_runtime,$(RUNTIME),--test_timeout=10800 //test/runtimes:$*) -%-runtime-tests_vfs2: load-runtimes_% - @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2") - @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") +%-runtime-tests_vfs2: load-runtimes_% $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),--vfs2) + @$(call test_runtime,$(RUNTIME),--test_timeout=10800 //test/runtimes:$*) -do-tests: runsc - @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true") - @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true") - @$(call submake,sudo TARGETS="//runsc" ARGS="do true") +do-tests: + @$(call run,//runsc,--rootless do true) + @$(call run,//runsc,--rootless -network=none do true) + @$(call sudo,//runsc,do true) .PHONY: do-tests simple-tests: unit-tests # Compatibility target. .PHONY: simple-tests -docker-tests: load-basic-images - @$(call submake,install-runtime RUNTIME="vfs1") - @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)") - @$(call submake,install-runtime RUNTIME="vfs2" ARGS="--vfs2") - @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)") +# Standard integration targets. +INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test + +docker-tests: load-basic $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),) # Clear flags. 
+ @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS)) + @$(call install_runtime,$(RUNTIME),--vfs2) + @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS)) .PHONY: docker-tests -overlay-tests: load-basic-images - @$(call submake,install-runtime RUNTIME="overlay" ARGS="--overlay") - @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)") +overlay-tests: load-basic $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),--overlay) + @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS)) .PHONY: overlay-tests -swgso-tests: load-basic-images - @$(call submake,install-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false") - @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)") +swgso-tests: load-basic $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),--software-gso=true --gso=false) + @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS)) .PHONY: swgso-tests -hostnet-tests: load-basic-images - @$(call submake,install-runtime RUNTIME="hostnet" ARGS="--network=host") - @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false --test_arg=-hostnet=true" TARGETS="$(INTEGRATION_TARGETS)") +hostnet-tests: load-basic $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),--network=host) + @$(call test_runtime,$(RUNTIME),--test_arg=-checkpoint=false --test_arg=-hostnet=true $(INTEGRATION_TARGETS)) .PHONY: hostnet-tests -kvm-tests: load-basic-images +kvm-tests: load-basic $(RUNTIME_BIN) @(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm - @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi - @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test") - @$(call submake,install-runtime RUNTIME="kvm" ARGS="--platform=kvm") - @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)") + @if ! 
test -w /dev/kvm; then sudo chmod a+rw /dev/kvm; fi + @$(call test,//pkg/sentry/platform/kvm:kvm_test) + @$(call install_runtime,$(RUNTIME),--platform=kvm) + @$(call test_runtime,$(RUNTIME),$(INTEGRATION_TARGETS)) .PHONY: kvm-tests -iptables-tests: load-iptables +iptables-tests: load-iptables $(RUNTIME_BIN) @sudo modprobe iptable_filter @sudo modprobe ip6table_filter - @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test") - @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw") - @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") + @$(call test,--test_arg=-runtime=runc $(PARTITIONS) //test/iptables:iptables_test) + @$(call install_runtime,$(RUNTIME),--net-raw) + @$(call test_runtime,$(RUNTIME),//test/iptables:iptables_test) .PHONY: iptables-tests -# Run the iptables tests with runsc only. Useful for developing to skip runc -# testing. -iptables-runsc-tests: load-iptables - @sudo modprobe iptable_filter - @sudo modprobe ip6table_filter - @$(call submake,install-runtime RUNTIME="iptables" ARGS="--net-raw") - @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") -.PHONY: iptables-runsc-tests - -packetdrill-tests: load-packetdrill - @$(call submake,install-runtime RUNTIME="packetdrill") - @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) -s query TARGETS='attr(tags, packetdrill, tests(//...))')") +packetdrill-tests: load-packetdrill $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),) # Clear flags. 
+ @$(call test_runtime,$(RUNTIME),//test/packetdrill:all_tests) .PHONY: packetdrill-tests -packetimpact-tests: load-packetimpact +packetimpact-tests: load-packetimpact $(RUNTIME_BIN) @sudo modprobe iptable_filter @sudo modprobe ip6table_filter - @$(call submake,install-runtime RUNTIME="packetimpact") - @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) -s query TARGETS='attr(tags, packetimpact, tests(//...))')") + @$(call install_runtime,$(RUNTIME),) # Clear flags. + @$(call test_runtime,$(RUNTIME),--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3 //test/packetimpact/tests:all_tests) .PHONY: packetimpact-tests # Specific containerd version tests. -containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu - @$(call submake,install-runtime RUNTIME="root") - @CONTAINERD_VERSION=$* $(MAKE) -s sudo TARGETS="tools/installers:containerd" - @$(MAKE) -s sudo TARGETS="tools/installers:shim" - @$(MAKE) -s sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v" +containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu $(RUNTIME_BIN) + @$(call install_runtime,$(RUNTIME),) # Clear flags. + @$(call sudo,tools/installers:containerd,$*) + @$(call sudo,tools/installers:shim) + @$(call sudo,test/root:root_test,--runtime=$(RUNTIME) -test.v) # Note that we can't run containerd-test-1.1.8 tests here. # @@ -270,53 +311,53 @@ containerd-tests: containerd-test-1.4.3 ## Targets to run benchmarks. See //test/benchmarks for details. ## ## common arguments: -## RUNTIME_ARGS - arguments to runsc placed in /etc/docker/daemon.json -## e.g. "--platform=ptrace" -## BENCHMARKS_PROJECT - BigQuery project to which to send data. -## BENCHMARKS_DATASET - BigQuery dataset to which to send data. -## BENCHMARKS_TABLE - BigQuery table to which to send data. 
-## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go. -## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run. -## BENCHMARKS_OFFICIAL - marks the data as official. +## BENCHMARKS_PROJECT - BigQuery project to which to send data. +## BENCHMARKS_DATASET - BigQuery dataset to which to send data. +## BENCHMARKS_TABLE - BigQuery table to which to send data. +## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go. +## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run. +## BENCHMARKS_OFFICIAL - marks the data as official. ## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm). +## BENCHMARKS_FILTER - filter to be applied to the test suite. +## BENCHMARKS_OPTIONS - options to be passed to the test. ## -BENCHMARKS_PROJECT := gvisor-benchmarks -BENCHMARKS_DATASET := kokoro -BENCHMARKS_TABLE := benchmarks -BENCHMARKS_SUITE := start -BENCHMARKS_UPLOAD := false -BENCHMARKS_OFFICIAL := false -BENCHMARKS_PLATFORMS := ptrace -BENCHMARKS_TARGETS := //test/benchmarks/base:startup_test -BENCHMARKS_ARGS := -test.bench=. -pprof-cpu -pprof-heap -pprof-heap -pprof-block - -init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema -## (see //tools/bigquery/bigquery.go). If the table alread exists, this is a noop. - $(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \ - --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)") +BENCHMARKS_PROJECT ?= gvisor-benchmarks +BENCHMARKS_DATASET ?= kokoro +BENCHMARKS_TABLE ?= benchmarks +BENCHMARKS_SUITE ?= ffmpeg +BENCHMARKS_UPLOAD ?= false +BENCHMARKS_OFFICIAL ?= false +BENCHMARKS_PLATFORMS ?= ptrace +BENCHMARKS_TARGETS := //test/benchmarks/media:ffmpeg_test +BENCHMARKS_FILTER := . 
+BENCHMARKS_OPTIONS := -test.benchtime=10s +BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex $(BENCHMARKS_OPTIONS) + +init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema. + @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)) .PHONY: init-benchmark-table -benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS. - $(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \ - $(call submake,run-benchmark RUNTIME="$(PLATFORM)" ARGS="--platform=$(PLATFORM) --vfs2") && \ - $(call submake,run-benchmark RUNTIME="$(PLATFORM)_vfs1" ARGS="--platform=$(PLATFORM)") && \ - ) \ - $(call submake, run-benchmark RUNTIME="runc") +# $(1) is the runtime name, $(2) are the arguments. +run_benchmark = \ + ($(call header,BENCHMARK $(1) $(2)); \ + set -euo pipefail; \ + if test "$(1)" != "runc"; then $(call install_runtime,$(1),--profile $(2)); fi; \ + export T=$$(mktemp --tmpdir logs.$(1).XXXXXX); \ + $(call sudo,$(BENCHMARKS_TARGETS),-runtime=$(1) $(BENCHMARKS_ARGS)) | tee $$T; \ + if test "$(BENCHMARKS_UPLOAD)" = "true"; then \ + $(call run,tools/parsers:parser,parse --debug --file=$$T --runtime=$(1) --suite_name=$(BENCHMARKS_SUITE) --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)); \ + fi; \ + rm -rf $$T) + +benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARKS_PLATFORMS. + @$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \ + $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) --vfs2) && \ + ) true + @$(call run_benchmark,runc) .PHONY: benchmark-platforms -run-benchmark: load-benchmarks-images ## Runs single benchmark and optionally sends data to BigQuery. 
- @if [[ "$(RUNTIME)" != "runc" ]]; then $(call submake,install-runtime ARGS="$(ARGS) --profile"); fi - @T=$$(mktemp --tmpdir logs.$(RUNTIME).XXXXXX); \ - $(call submake,sudo TARGETS="$(BENCHMARKS_TARGETS)" ARGS="--runtime=$(RUNTIME) $(BENCHMARKS_ARGS) | tee $$T"); \ - rc=$$?; \ - if [[ $$rc -eq 0 ]] && [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \ - $(call submake,run TARGETS="tools/parsers:parser" ARGS="parse --debug --file=$$T \ - --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \ - --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \ - --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \ - fi; \ - rm -rf $$T; \ - exit $$rc +run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery. + @$(call run_benchmark,$(RUNTIME),) .PHONY: run-benchmark ## @@ -336,7 +377,7 @@ WEBSITE_PROJECT := gvisordev WEBSITE_REGION := us-central1 website-build: load-jekyll ## Build the site image locally. - @$(call submake,run TARGETS="//website:website" ARGS="$(WEBSITE_IMAGE)") + @$(call run,//website:website,$(WEBSITE_IMAGE)) .PHONY: website-build website-server: website-build ## Run a local server for development. @@ -362,17 +403,17 @@ website-deploy: website-push ## Deploy a new version of the website. ## RELEASE_NAME - The name of the release in the proper format (needed for tag). ## RELEASE_NOTES - The file containing release notes (needed for tag). ## -RELEASE_ROOT := $(CURDIR)/repo -RELEASE_KEY := repo.key -RELEASE_NIGHTLY := false -RELEASE_COMMIT := -RELEASE_NAME := -RELEASE_NOTES := - +RELEASE_ROOT := $(CURDIR)/repo +RELEASE_KEY := repo.key +RELEASE_NIGHTLY := false +RELEASE_COMMIT := +RELEASE_NAME := +RELEASE_NOTES := GPG_TEST_OPTIONS := $(shell if gpg --pinentry-mode loopback --version >/dev/null 2>&1; then echo --pinentry-mode loopback; fi) + $(RELEASE_KEY): @echo "WARNING: Generating a key for testing ($@); don't use this." 
- T=$$(mktemp --tmpdir keyring.XXXXXX); \ + @T=$$(mktemp --tmpdir keyring.XXXXXX); \ C=$$(mktemp --tmpdir config.XXXXXX); \ echo Key-Type: DSA >> $$C && \ echo Key-Length: 1024 >> $$C && \ @@ -386,11 +427,11 @@ $(RELEASE_KEY): release: $(RELEASE_KEY) ## Builds a release. @mkdir -p $(RELEASE_ROOT) - @T=$$(mktemp -d --tmpdir release.XXXXXX); \ - $(call submake,copy TARGETS="//runsc:runsc" DESTINATION=$$T) && \ - $(call submake,copy TARGETS="//shim/v1:gvisor-containerd-shim" DESTINATION=$$T) && \ - $(call submake,copy TARGETS="//shim/v2:containerd-shim-runsc-v1" DESTINATION=$$T) && \ - $(call submake,copy TARGETS="//debian:debian" DESTINATION=$$T) && \ + @export T=$$(mktemp -d --tmpdir release.XXXXXX); \ + $(call copy,//runsc:runsc,$$T) && \ + $(call copy,//shim/v1:gvisor-containerd-shim,$$T) && \ + $(call copy,//shim/v2:containerd-shim-runsc-v1,$$T) && \ + $(call copy,//debian:debian,$$T) && \ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \ rc=$$?; rm -rf $$T; exit $$rc .PHONY: release @@ -398,74 +439,3 @@ release: $(RELEASE_KEY) ## Builds a release. tag: ## Creates and pushes a release tag. @tools/tag_release.sh "$(RELEASE_COMMIT)" "$(RELEASE_NAME)" "$(RELEASE_NOTES)" .PHONY: tag - -## -## Development helpers and tooling. -## -## These targets faciliate local development by automatically -## installing and configuring a runtime. Several variables may -## be used here to tweak the installation: -## RUNTIME - The name of the installed runtime (default: branch). -## RUNTIME_DIR - Where the runtime will be installed (default: temporary directory with the $RUNTIME). -## RUNTIME_BIN - The runtime binary (default: $RUNTIME_DIR/runsc). -## RUNTIME_LOG_DIR - The logs directory (default: $RUNTIME_DIR/logs). -## RUNTIME_LOGS - The log pattern (default: $RUNTIME_LOG_DIR/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%). 
-## -ifeq (,$(BRANCH_NAME)) -RUNTIME := runsc -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) -else -RUNTIME := $(BRANCH_NAME) -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) -endif -RUNTIME_BIN := $(RUNTIME_DIR)/runsc -RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs -RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% - -dev: ## Installs a set of local runtimes. Requires sudo. - @$(call submake,refresh) - @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw") - @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets") - @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile") - @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2") - @sudo systemctl restart docker -.PHONY: dev - -refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-runtime' first. - @mkdir -p "$(RUNTIME_DIR)" - @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)") -.PHONY: refresh - -install-runtime: ## Installs the runtime for testing. Requires sudo. - @$(call submake,refresh) - @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="$(ARGS) --TESTONLY-test-name-env=RUNSC_TEST_NAME") - @sudo systemctl restart docker - @if [[ -f /etc/docker/daemon.json ]]; then \ - sudo chmod 0755 /etc/docker && \ - sudo chmod 0644 /etc/docker/daemon.json; \ - fi -.PHONY: install-runtime - -install-debug-runtime: ## Installs the runtime for debugging. Requires sudo. - @$(call submake,install-runtime ARGS="--debug --strace --log-packets $(ARGS)") -.PHONY: install-debug-runtime - -configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-runtime. 
- @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) - @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" - @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)" - @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)" -.PHONY: configure - -test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided. - @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME) $(PARTITIONS)") -.PHONY: test-runtime - -nogo: ## Surfaces all nogo findings. - @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...") - @$(call submake,run TARGETS="//tools/github" ARGS="$(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo") -.PHONY: nogo - -gazelle: ## Runs gazelle to update WORKSPACE. - @$(call submake,run TARGETS="//:gazelle" ARGS="update-repos -from_file=go.mod -prune") -.PHONY: gazelle @@ -1,6 +1,6 @@ ![gVisor](g3doc/logo.png) -![](https://github.com/google/gvisor/workflows/Build/badge.svg) +[![Build status](https://badge.buildkite.com/3b159f20b9830461a71112566c4171c0bdfd2f980a8e4c0ae6.svg?branch=master)](https://buildkite.com/gvisor/pipeline) [![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community) [![code search](https://img.shields.io/badge/code-search-blue)](https://cs.opensource.google/gvisor/gvisor) @@ -1,4 +1,4 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive", "http_file") load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # Bazel/starlark utilities. @@ -176,6 +176,19 @@ http_archive( ], ) +# Schemas for testing. 
+http_file( + name = "buildkite_pipeline_schema", + sha256 = "3369c58038b4d55c08928affafb653716eb1e7b3cabb4a391aef979dd921f4e1", + urls = ["https://raw.githubusercontent.com/buildkite/pipeline-schema/f7a0894074d194bcf19eec5411fec0528f7f4180/schema.json"], +) + +http_file( + name = "github_workflow_schema", + sha256 = "2c375bb43dbc8b32b1bed46c290d0b70a8fa2aca7a5484dfca1b6e9c38cf9e7a", + urls = ["https://raw.githubusercontent.com/SchemaStore/schemastore/27612065234778feaac216ce14dd47846fe0a2dd/src/schemas/json/github-workflow.json"], +) + # External Go repositories. # # Unfortunately, gazelle will automatically parse go modules in the @@ -524,8 +537,8 @@ go_repository( name = "com_github_containerd_cgroups", build_file_proto_mode = "disable", importpath = "github.com/containerd/cgroups", - sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=", - version = "v0.0.0-20181219155423-39b18af02c41", + sum = "h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw=", + version = "v0.0.0-20201119153540-4cbc285b3327", ) go_repository( @@ -1391,3 +1404,24 @@ go_repository( sum = "h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=", version = "v0.0.0-20190801114015-581e00157fb1", ) + +go_repository( + name = "com_github_xeipuuv_gojsonpointer", + importpath = "github.com/xeipuuv/gojsonpointer", + sum = "h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=", + version = "v0.0.0-20190905194746-02993c407bfb", +) + +go_repository( + name = "com_github_xeipuuv_gojsonreference", + importpath = "github.com/xeipuuv/gojsonreference", + sum = "h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0=", + version = "v0.0.0-20180127040603-bd5ef7bd5415", +) + +go_repository( + name = "com_github_xeipuuv_gojsonschema", + importpath = "github.com/xeipuuv/gojsonschema", + sum = "h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74=", + version = "v1.2.0", +) diff --git a/g3doc/proposals/runtime_dedicate_os_thread.md b/g3doc/proposals/runtime_dedicate_os_thread.md new file mode 100644 index 000000000..dc70055b0 --- /dev/null 
+++ b/g3doc/proposals/runtime_dedicate_os_thread.md @@ -0,0 +1,188 @@ +# `runtime.DedicateOSThread` + +Status as of 2020-09-18: Deprioritized; initial studies in #2180 suggest that +this may be difficult to support in the Go runtime due to issues with GC. + +## Summary + +Allow goroutines to bind to kernel threads in a way that allows their scheduling +to be kernel-managed rather than runtime-managed. + +## Objectives + +* Reduce Go runtime overhead in the gVisor sentry (#2184). + +* Minimize intrusiveness of changes to the Go runtime. + +## Background + +In Go, execution contexts are referred to as goroutines, which the runtime calls +Gs. The Go runtime maintains a variably-sized pool of threads (called Ms by the +runtime) on which Gs are executed, as well as a pool of "virtual processors" +(called Ps by the runtime) of size equal to `runtime.GOMAXPROCS()`. Usually, +each M requires a P in order to execute Gs, limiting the number of concurrently +executing goroutines to `runtime.GOMAXPROCS()`. + +The `runtime.LockOSThread` function temporarily locks the invoking goroutine to +its current thread. It is primarily useful for interacting with OS or non-Go +library facilities that are per-thread. It does not reduce interactions with the +Go runtime scheduler: locked Ms relinquish their P when they become blocked, and +only continue execution after another M "chooses" their locked G to run and +donates their P to the locked M instead. + +## Problems + +### Context Switch Overhead + +Most goroutines in the gVisor sentry are task goroutines, which back application +threads. Task goroutines spend large amounts of time blocked on syscalls that +execute untrusted application code. When invoking said syscall (which varies by +gVisor platform), the task goroutine may interact with the Go runtime in one of +three ways: + +* It can invoke the syscall without informing the runtime. 
In this case, the + task goroutine will continue to hold its P during the syscall, limiting the + number of application threads that can run concurrently to + `runtime.GOMAXPROCS()`. This is problematic because the Go runtime scheduler + is known to scale poorly with `GOMAXPROCS`; see #1942 and + https://github.com/golang/go/issues/28808. It also means that preemption of + application threads must be driven by sentry or runtime code, which is + strictly slower than kernel-driven preemption (since the sentry must invoke + another syscall to preempt the application thread). + +* It can call `runtime.entersyscallblock` before invoking the syscall, and + `runtime.exitsyscall` after the syscall returns. In this case, the task + goroutine will release its P while the syscall is executing. This allows the + number of threads concurrently executing application code to exceed + `GOMAXPROCS`. However, this incurs additional latency on syscall entry (to + hand off the released P to another M, often requiring a `futex(FUTEX_WAKE)` + syscall) and on syscall exit (to acquire a new P). It also drastically + increases the number of threads that concurrently interact with the runtime + scheduler, which is also problematic for performance (both in terms of CPU + utilization and in terms of context switch latency); see #205. + +* It can call `runtime.entersyscall` before invoking the syscall, and + `runtime.exitsyscall` after the syscall returns. In this case, the task + goroutine "lazily releases" its P, allowing the runtime's "sysmon" thread to + steal it on behalf of another M after a 20us delay. This mitigates the + context switch latency problem when there are few task goroutines and the + interval between switches to application code (i.e. the interval between + application syscalls, page faults, or signal delivery) is short. (Cynically, + this means that it's most effective in microbenchmarks.) 
However, the delay + before a P is stolen can also be problematic for performance when there are + both many task goroutines switching to application code (lazily releasing + their Ps) *and* many task goroutines switching to sentry code (contending + for Ps), which is likely in larger heterogeneous workloads. + +### Blocking Overhead + +Task goroutines block on behalf of application syscalls like `futex` and +`epoll_wait` by receiving from a Go channel. (Future work may convert task +goroutine blocking to use the `syncevent` package to avoid overhead associated +with channels and `select`, but this does not change how blocking interacts with +the Go runtime scheduler.) + +If `runtime.LockOSThread()` is not in effect when a task goroutine blocks, then +when the task goroutine is unblocked (by e.g. an application `FUTEX_WAKE`, +signal delivery, or a timeout) by sending to the blocked channel, +`runtime.ready` migrates the unblocked G to the unblocking P. In most cases, +this implies that every application thread block/unblock cycle results in a +migration of the thread between Ps, and therefore Ms, and therefore cores, +resulting in reduced application performance due to loss of CPU caches. +Furthermore, in most cases, the unblocking P cannot immediately switch to the +unblocked G (instead resuming execution of its current application thread after +completing the application's `futex(FUTEX_WAKE)`, `tgkill`, etc. syscall), often +requiring that another P steal the unblocked G before it can resume execution. 
+ +If `runtime.LockOSThread()` is in effect when a task goroutine blocks, then the +G will remain locked to its M, avoiding the core migration described above; +however, wakeup latency is significantly increased since, as described in +"Background", the G still needs to be selected by the scheduler before it can +run, and the M that selects the G then needs to transfer its P to the locked M, +incurring an additional `FUTEX_WAKE` syscall and round of kernel scheduling. + +## Proposal + +We propose to add a function, tentatively called `DedicateOSThread`, to the Go +`runtime` package, documented as follows: + +```go +// DedicateOSThread wires the calling goroutine to its current operating system +// thread, and exempts it from counting against GOMAXPROCS. The calling +// goroutine will always execute in that thread, and no other goroutine will +// execute in it, until the calling goroutine has made as many calls to +// UndedicateOSThread as to DedicateOSThread. If the calling goroutine exits +// without unlocking the thread, the thread will be terminated. +// +// DedicateOSThread should only be used by long-lived goroutines that usually +// block due to blocking system calls, rather than interaction with other +// goroutines. +func DedicateOSThread() +``` + +Mechanically, `DedicateOSThread` implies `LockOSThread` (i.e. it locks the +invoking G to a M), but additionally locks the invoking M to a P. Ps locked by +`DedicateOSThread` are not counted against `GOMAXPROCS`; that is, the actual +number of Ps in the system (`len(runtime.allp)`) is `GOMAXPROCS` plus the number +of bound Ps (plus some slack to avoid frequent changes to `runtime.allp`). +Corollaries: + +* If `runtime.ready` observes that a readied G is locked to a M locked to a P, + it immediately wakes the locked M without migrating the G to the readying P + or waiting for a future call to `runtime.schedule` to select the readied G + in `runtime.findrunnable`. 
+ +* `runtime.stoplockedm` and `runtime.reentersyscall` skip the release of + locked Ps; the latter also skips sysmon wakeup. `runtime.stoplockedm` and + `runtime.exitsyscall` skip re-acquisition of Ps if one is locked. + +* sysmon does not attempt to preempt Gs that are locked to Ps, avoiding + fruitless overhead from `tgkill` syscalls and signal delivery. + +* `runtime.findrunnable`'s work stealing skips locked Ps (suggesting that + unlocked Ps be tracked in a separate array). `runtime.findrunnable` on + locked Ps skips the global run queue, work stealing, and possibly netpoll. + +* New goroutines created by goroutines with locked Ps are enqueued on the + global run queue rather than the invoking P's local run queue. + +While gVisor's use case does not strictly require that the association is +reversible (with `runtime.UndedicateOSThread`), such a feature is required to +allow reuse of locked Ms, which is likely to be critical for performance. + +## Alternatives Considered + +* Make the runtime scale well with `GOMAXPROCS`. While we are also + concurrently investigating this problem, this would not address the issues + of increased preemption cost or blocking overhead. + +* Make the runtime scale well with the number of Ms. It is unclear if this is + actually feasible, and would not address blocking overhead. + +* Make P-locking part of `LockOSThread`'s behavior. This would likely + introduce performance regressions in existing uses of `LockOSThread` that do + not fit this usage pattern. In particular, since `DedicateOSThread` + transitions the invoker's P from "counted against `GOMAXPROCS`" to "not + counted against `GOMAXPROCS`", it may need to wake another M to run a new P + (that is counted against `GOMAXPROCS`), and the converse applies to + `UndedicateOSThread`. + +* Rewrite the gVisor sentry in a language that does not force userspace + scheduling. This is a last resort due to the amount of code involved. 
+ +## Related Issues + +The proposed functionality is directly analogous to `spawn_blocking` in Rust +async runtimes +[`async_std`](https://docs.rs/async-std/1.8.0/async_std/task/fn.spawn_blocking.html) +and [`tokio`](https://docs.rs/tokio/0.3.5/tokio/task/fn.spawn_blocking.html). + +Outside of gVisor: + +* https://github.com/golang/go/issues/21827#issuecomment-595152452 describes a + use case for this feature in go-delve, where the goroutine that would use + this feature spends much of its time blocked in `ptrace` syscalls. + +* This feature may improve performance in the use case described in + https://github.com/golang/go/issues/18237, given the prominence of + syscall.Syscall in the profile given in that bug report. @@ -10,7 +10,7 @@ require ( github.com/Microsoft/hcsshim v0.8.6 // indirect github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect - github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect + github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 github.com/containerd/containerd v1.3.9 // indirect github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect @@ -51,6 +51,8 @@ github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssN github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI= github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59 h1:qWj4qVYZ95vLWwqyNJCQg7rDsG5wPdze0UaPolH7DUk= github.com/containerd/cgroups v0.0.0-20200531161412-0dbf7f05ba59/go.mod h1:pA0z1pT8KYB3TCXK/ocprsh7MAkoW8bZVzPdih9snmM= +github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327 h1:7grrpcfCtbZLsjtB0DgMuzs1umsJmpzaHMZ6cO6iAWw= +github.com/containerd/cgroups v0.0.0-20201119153540-4cbc285b3327/go.mod h1:ZJeTFisyysqgcCdecO57Dj79RfL0LNeGiFUqLYQRYLE= 
github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc= github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE= diff --git a/images/BUILD b/images/BUILD index a50f388e9..34b950644 100644 --- a/images/BUILD +++ b/images/BUILD @@ -1,11 +1 @@ package(licenses = ["notice"]) - -# The images filegroup is definitely not a hermetic target, and requires Make -# to do anything meaningful with. However, this will be slurped up and used by -# the tools/installer/images.sh installer, which will ensure that all required -# images are available locally when running vm_tests. -filegroup( - name = "images", - srcs = glob(["**"]), - visibility = ["//tools/installers:__pkg__"], -) diff --git a/images/Makefile b/images/Makefile deleted file mode 100644 index 66aac7802..000000000 --- a/images/Makefile +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/make -f - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ARCH is the architecture used for the build. This may be overriden at the -# command line in order to perform a cross-build (in a limited capacity). -ARCH := $(shell uname -m) - -# Note that the image prefixes used here must match the image mangling in -# runsc/testutil.MangleImage. 
Names are mangled in this way to ensure that all -# tests are using locally-defined images (that are consistent and idempotent). -REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit -LOCAL_IMAGE_PREFIX ?= gvisor.dev/images -ALL_IMAGES := $(subst /,_,$(subst ./,,$(shell find . -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq))) -ifneq ($(ARCH),$(shell uname -m)) -DOCKER_PLATFORM_ARGS := --platform=$(ARCH) -else -DOCKER_PLATFORM_ARGS := -endif - -list-all-images: - @for image in $(ALL_IMAGES); do echo $${image}; done -.PHONY: list-build-images - -# Handy wrapper to allow load-all-images, push-all-images, etc. -%-all-images: - @$(MAKE) -s $(patsubst %,$*-%,$(ALL_IMAGES)) -load-all-images: - @$(MAKE) -s $(patsubst %,load-%,$(ALL_IMAGES)) - -# Handy wrapper to load specified "groups", e.g. load-basic-images, etc. -load-%-images: - @$(MAKE) -s $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;)))) - -# tag is a function that returns the tag name, given an image. -# -# The tag constructed is used to memoize the image generated (see README.md). -# This scheme is used to enable aggressive caching in a central repository, but -# ensuring that images will always be sourced using the local files if there -# are changes. -path = $(subst _,/,$(1)) -dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi) -tag = $(shell find $(call path,$(1)) -type f -print | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16) -remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH):$(call tag,$(1)) -local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) - -# rebuild builds the image locally. Only the "remote" tag will be applied. Note -# we need to explicitly repull the base layer in order to ensure that the -# architecture is correct. 
Note that we use the term "rebuild" here to avoid -# conflicting with the bazel "build" terminology, which is used elsewhere. -rebuild-%: FROM=$(shell grep FROM "$(call path,$*)/$(call dockerfile,$*)" | cut -d' ' -f2) -rebuild-%: register-cross - @if ! [ -f "$(call path,$*)/$(call dockerfile,$*)" ]; then \ - (echo "ERROR: Dockerfile for $* not found (is it available for $(ARCH)?)." >&2 && exit 1); \ - fi - $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \ - T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \ - docker build $(DOCKER_PLATFORM_ARGS) \ - -f "$$T/$(call dockerfile,$*)" \ - -t "$(call remote_image,$*)" \ - $$T && \ - rm -rf $$T - -# pull will check the "remote" image and pull if necessary. If the remote image -# must be pulled, then it will tag with the latest local target. Note that pull -# may fail if the remote image is not available. -pull-%: - docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*) - -# load will either pull the "remote" or build it locally. This is the preferred -# entrypoint, as it should never fail. The local tag should always be set after -# this returns (either by the pull or the build). -load-%: - $(MAKE) -s pull-$* || $(MAKE) -s rebuild-$* - docker tag $(call remote_image,$*) $(call local_image,$*) - -# push pushes the remote image, after either pulling (to validate that the tag -# already exists) or building manually. -push-%: load-% - docker push $(call remote_image,$*) - -# register-cross registers the necessary qemu binaries for cross-compilation. -# This may be used by any target that may execute containers that are not the -# native format. -register-cross: -ifneq ($(ARCH),$(shell uname -m)) -ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*)) - docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes -else - @true # Already registered. -endif -else - @true # No cross required. 
-endif -.PHONY: register-cross diff --git a/images/benchmarks/absl/Dockerfile b/images/benchmarks/absl/Dockerfile.x86_64 index b0dd97695..810c9ef5e 100644 --- a/images/benchmarks/absl/Dockerfile +++ b/images/benchmarks/absl/Dockerfile.x86_64 @@ -12,6 +12,7 @@ RUN set -x \ unzip \ python3 \ && rm -rf /var/lib/apt/lists/* + RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh RUN ./bazel-0.27.0-installer-linux-x86_64.sh diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile index f586978b6..4b6a0f849 100644 --- a/images/benchmarks/hey/Dockerfile +++ b/images/benchmarks/hey/Dockerfile @@ -1,12 +1,13 @@ -FROM ubuntu:18.04 +FROM golang:1.15 as build +RUN go get github.com/rakyll/hey +WORKDIR /go/src/github.com/rakyll/hey +RUN go mod download +RUN CGO_ENABLED=0 go build -o /hey hey.go +FROM ubuntu:18.04 RUN set -x \ && apt-get update \ && apt-get install -y \ wget \ && rm -rf /var/lib/apt/lists/* - -RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \ - && chmod 777 hey_linux_amd64 \ - && cp hey_linux_amd64 /bin/hey \ - && rm hey_linux_amd64 +COPY --from=build /hey /bin/hey diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile.x86_64 index 6c3aafa57..28ae64816 100644 --- a/images/benchmarks/runsc/Dockerfile +++ b/images/benchmarks/runsc/Dockerfile.x86_64 @@ -14,6 +14,7 @@ RUN set -x \ python3 \ python3-pip \ && rm -rf /var/lib/apt/lists/* + RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh RUN ./bazel-3.4.1-installer-linux-x86_64.sh diff --git a/images/default/Dockerfile b/images/default/Dockerfile index d058b83cb..224469267 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -1,16 +1,20 @@ FROM fedora:31 + # Install bazel. 
RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils RUN pip install --no-cache-dir pycparser RUN dnf install -y bazel3 -# Install gcloud. + +# Install gcloud. Note that while this is "x86_64", it doesn't actually matter. RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \ - tar zxvf - google-cloud-sdk && \ + tar zxf - google-cloud-sdk && \ google-cloud-sdk/install.sh && \ ln -s /google-cloud-sdk/bin/gcloud /usr/bin/gcloud + # Install Docker client for the website build. RUN dnf config-manager --add-repo https://download.docker.com/linux/fedora/docker-ce.repo RUN dnf install -y docker-ce-cli + WORKDIR /workspace ENTRYPOINT ["/usr/bin/bazel"] diff --git a/images/runtimes/go1.12/Dockerfile b/images/runtimes/go1.12/Dockerfile.x86_64 index cb2944062..cb2944062 100644 --- a/images/runtimes/go1.12/Dockerfile +++ b/images/runtimes/go1.12/Dockerfile.x86_64 @@ -56,123 +56,8 @@ global: - "should not use ALL_CAPS in Go names" - "should not use underscores in Go names" exclude: - # A variety of staticcheck and stylecheck - # rules apply here. These should be fixed - # and removed from here, and the global - # rules should be used sparingly. 
- - pkg/abi/linux/fuse.go:22 - - pkg/abi/linux/fuse.go:25 - - pkg/abi/linux/socket.go:113 - - pkg/abi/linux/tty.go:73 - - pkg/cpuid/cpuid_x86.go:675 - - pkg/gohacks/gohacks_unsafe.go:33 - - pkg/log/json.go:30 - - pkg/log/log.go:359 - - pkg/metric/metric_test.go:20 - - pkg/p9/p9test/client_test.go:687 - - pkg/p9/transport_test.go:196 - - pkg/pool/pool.go:15 - - pkg/refs/refcounter.go:510 - - pkg/refs/refcounter_test.go:169 - - pkg/safemem/block_unsafe.go:89 - - pkg/seccomp/seccomp.go:82 - - pkg/segment/test/set_functions.go:15 - - pkg/sentry/arch/signal.go:166 - - pkg/sentry/arch/signal.go:171 - - pkg/sentry/control/pprof.go:196 - - pkg/sentry/devices/memdev/full.go:58 - - pkg/sentry/devices/memdev/null.go:59 - - pkg/sentry/devices/memdev/random.go:68 - - pkg/sentry/devices/memdev/zero.go:86 - - pkg/sentry/fdimport/fdimport.go:15 - - pkg/sentry/fs/attr.go:257 - - pkg/sentry/fsbridge/fs.go:116 - - pkg/sentry/fsbridge/vfs.go:124 - - pkg/sentry/fsbridge/vfs.go:70 - - pkg/sentry/fs/copy_up.go:365 - - pkg/sentry/fs/copy_up_test.go:65 - - pkg/sentry/fs/dev/net_tun.go:161 - - pkg/sentry/fs/dev/net_tun.go:63 - - pkg/sentry/fs/dev/null.go:97 - - pkg/sentry/fs/dirent_cache.go:64 - - pkg/sentry/fs/fdpipe/pipe_opener_test.go:366 - - pkg/sentry/fs/file_overlay.go:327 - - pkg/sentry/fs/file_overlay.go:524 - - pkg/sentry/fs/filetest/filetest.go:55 - - pkg/sentry/fs/filetest/filetest.go:60 - - pkg/sentry/fs/fs.go:77 - - pkg/sentry/fs/fsutil/file.go:290 - - pkg/sentry/fs/fsutil/file.go:346 - - pkg/sentry/fs/fsutil/host_file_mapper.go:105 - - pkg/sentry/fs/fsutil/inode_cached.go:676 - - pkg/sentry/fs/fsutil/inode_cached.go:772 - - pkg/sentry/fs/gofer/attr.go:120 - - pkg/sentry/fs/gofer/fifo.go:33 - - pkg/sentry/fs/gofer/inode.go:410 - - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97 - - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92 - - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44 - - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91 - - 
pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93 - - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66 - - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53 - - pkg/sentry/fsimpl/fuse/request_response.go:71 - - pkg/sentry/fsimpl/signalfd/signalfd.go:15 - - pkg/sentry/memmap/memmap.go:103 - - pkg/sentry/memmap/memmap.go:163 - - pkg/sentry/mm/aio_context.go:208 - - pkg/sentry/mm/pma.go:683 - - pkg/sentry/usage/cpu.go:42 - - pkg/shim/runsc/runsc.go:16 - - pkg/shim/runsc/utils.go:16 - - pkg/shim/v1/proc/deleted_state.go:16 - - pkg/shim/v1/proc/exec.go:16 - - pkg/shim/v1/proc/exec_state.go:16 - - pkg/shim/v1/proc/init.go:16 - - pkg/shim/v1/proc/init_state.go:16 - - pkg/shim/v1/proc/io.go:16 - - pkg/shim/v1/proc/process.go:16 - - pkg/shim/v1/proc/types.go:16 - - pkg/shim/v1/proc/utils.go:16 - - pkg/shim/v1/shim/api.go:16 - - pkg/shim/v1/shim/platform.go:16 - - pkg/shim/v1/shim/service.go:16 - - pkg/shim/v1/utils/annotations.go:15 - - pkg/shim/v1/utils/utils.go:15 - - pkg/shim/v1/utils/volumes.go:15 - - pkg/shim/v2/api.go:16 - - pkg/shim/v2/epoll.go:18 - - pkg/shim/v2/options/options.go:15 - - pkg/shim/v2/options/options.go:24 - - pkg/shim/v2/options/options.go:26 - - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16 - - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all. 
- - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22 - - pkg/shim/v2/service.go:15 - - pkg/shim/v2/service_linux.go:18 - - pkg/state/tests/integer_test.go:23 - - pkg/state/tests/integer_test.go:28 - - pkg/sync/rwmutex_test.go:105 - - pkg/syserr/host_linux.go:35 - - pkg/usermem/addr.go:34 - - pkg/usermem/usermem.go:171 - - pkg/usermem/usermem.go:170 - - runsc/boot/compat.go:56 - - test/cmd/test_app/fds.go:171 - - test/iptables/filter_output.go:251 - - test/packetimpact/testbench/connections.go:77 - - tools/bigquery/bigquery.go:106 - - tools/checkescape/test1/test1.go:108 - - tools/checkescape/test1/test1.go:122 - - tools/checkescape/test1/test1.go:137 - - tools/checkescape/test1/test1.go:151 - - tools/checkescape/test1/test1.go:170 - - tools/checkescape/test1/test1.go:39 - - tools/checkescape/test1/test1.go:45 - - tools/checkescape/test1/test1.go:50 - - tools/checkescape/test1/test1.go:64 - - tools/checkescape/test1/test1.go:80 - - tools/checkescape/test1/test1.go:94 + # Generated: exempt all. + - pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go analyzers: asmdecl: external: # Enabled. @@ -214,6 +99,8 @@ analyzers: printf: external: # Enabled. shift: + generated: # Disabled for generated code; these shifts are well-defined. + exclude: [".*"] external: # Enabled. stringintconv: external: @@ -250,3 +137,22 @@ analyzers: external: # Enabled. checkescape: external: # Enabled. + SA4016: + internal: + exclude: + - pkg/gohacks/gohacks_unsafe.go # x ^ 0 always equals x. + SA2001: + internal: + exclude: + - pkg/sentry/fs/fs.go # Intentional. + - pkg/sentry/fs/gofer/inode.go # Intentional. + - pkg/refs/refcounter_test.go # Intentional. + ST1021: + internal: + suppress: + - "comment on exported type Translation" # Intentional. + - "comment on exported type PinnedRange" # Intentional. + SA5011: + internal: + exclude: + - pkg/sentry/fs/fdpipe/pipe_opener_test.go # False positive. 
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index a0654df2f..8fa61d6f7 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -21,6 +21,7 @@ go_library( "epoll_amd64.go", "epoll_arm64.go", "errors.go", + "errqueue.go", "eventfd.go", "exec.go", "fadvise.go", diff --git a/pkg/abi/linux/errqueue.go b/pkg/abi/linux/errqueue.go new file mode 100644 index 000000000..3905d4222 --- /dev/null +++ b/pkg/abi/linux/errqueue.go @@ -0,0 +1,93 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "gvisor.dev/gvisor/pkg/marshal" +) + +// Socket error origin codes as defined in include/uapi/linux/errqueue.h. +const ( + SO_EE_ORIGIN_NONE = 0 + SO_EE_ORIGIN_LOCAL = 1 + SO_EE_ORIGIN_ICMP = 2 + SO_EE_ORIGIN_ICMP6 = 3 +) + +// SockExtendedErr represents struct sock_extended_err in Linux defined in +// include/uapi/linux/errqueue.h. +// +// +marshal +type SockExtendedErr struct { + Errno uint32 + Origin uint8 + Type uint8 + Code uint8 + Pad uint8 + Info uint32 + Data uint32 +} + +// SockErrCMsg represents the IP*_RECVERR control message. +type SockErrCMsg interface { + marshal.Marshallable + + CMsgLevel() uint32 + CMsgType() uint32 +} + +// SockErrCMsgIPv4 is the IP_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv4 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv4/ip_sockglue.c:ip_recv_error(). 
+// +// +marshal +type SockErrCMsgIPv4 struct { + SockExtendedErr + Offender SockAddrInet +} + +var _ SockErrCMsg = (*SockErrCMsgIPv4)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv4) CMsgLevel() uint32 { + return SOL_IP +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv4) CMsgType() uint32 { + return IP_RECVERR +} + +// SockErrCMsgIPv6 is the IPV6_RECVERR control message used in +// recvmsg(MSG_ERRQUEUE) by ipv6 sockets. This is equilavent to `struct errhdr` +// defined in net/ipv6/datagram.c:ipv6_recv_error(). +// +// +marshal +type SockErrCMsgIPv6 struct { + SockExtendedErr + Offender SockAddrInet6 +} + +var _ SockErrCMsg = (*SockErrCMsgIPv6)(nil) + +// CMsgLevel implements SockErrCMsg.CMsgLevel. +func (*SockErrCMsgIPv6) CMsgLevel() uint32 { + return SOL_IPV6 +} + +// CMsgType implements SockErrCMsg.CMsgType. +func (*SockErrCMsgIPv6) CMsgType() uint32 { + return IPV6_RECVERR +} diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index d91c97a64..1070b457c 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -19,16 +19,22 @@ import ( "gvisor.dev/gvisor/pkg/marshal/primitive" ) +// FUSEOpcode is a FUSE operation code. +// // +marshal type FUSEOpcode uint32 +// FUSEOpID is a FUSE operation ID. +// // +marshal type FUSEOpID uint64 // FUSE_ROOT_ID is the id of root inode. const FUSE_ROOT_ID = 1 -// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +// Opcodes for FUSE operations. +// +// Analogous to the opcodes in include/linux/fuse.h. 
const ( FUSE_LOOKUP FUSEOpcode = 1 FUSE_FORGET = 2 /* no reply */ diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 0adff8dff..2424884c1 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -43,10 +43,10 @@ const ( SEMVMX = 32767 SEMAEM = SEMVMX - // followings are unused in kernel SEMUME = SEMOPM SEMMNU = SEMMNS SEMMAP = SEMMNS + SEMUSZ = 20 ) const SEM_UNDO = 0x1000 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index d156d41e4..556892dc3 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -111,12 +111,12 @@ type SockType int // Socket types, from linux/net.h. const ( SOCK_STREAM SockType = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_DGRAM SockType = 2 + SOCK_RAW SockType = 3 + SOCK_RDM SockType = 4 + SOCK_SEQPACKET SockType = 5 + SOCK_DCCP SockType = 6 + SOCK_PACKET SockType = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are @@ -448,6 +448,8 @@ type ControlMessageCredentials struct { // A ControlMessageIPPacketInfo is IP_PKTINFO socket control message. // // ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h. +// +// +stateify savable type ControlMessageIPPacketInfo struct { NIC int32 LocalAddr InetAddr diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go index a4f4b2c5e..fdfe31417 100644 --- a/pkg/coverage/coverage.go +++ b/pkg/coverage/coverage.go @@ -27,6 +27,7 @@ import ( "io" "sort" "sync/atomic" + "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -34,12 +35,6 @@ import ( "github.com/bazelbuild/rules_go/go/tools/coverdata" ) -// KcovAvailable returns whether the kcov coverage interface is available. It is -// available as long as coverage is enabled for some files. -func KcovAvailable() bool { - return len(coverdata.Cover.Blocks) > 0 -} - // coverageMu must be held while accessing coverdata.Cover. 
This prevents // concurrent reads/writes from multiple threads collecting coverage data. var coverageMu sync.RWMutex @@ -47,6 +42,22 @@ var coverageMu sync.RWMutex // once ensures that globalData is only initialized once. var once sync.Once +// blockBitLength is the number of bits used to represent coverage block index +// in a synthetic PC (the rest are used to represent the file index). Even +// though a PC has 64 bits, we only use the lower 32 bits because some users +// (e.g., syzkaller) may truncate that address to a 32-bit value. +// +// As of this writing, there are ~1200 files that can be instrumented and at +// most ~1200 blocks per file, so 16 bits is more than enough to represent every +// file and every block. +const blockBitLength = 16 + +// KcovAvailable returns whether the kcov coverage interface is available. It is +// available as long as coverage is enabled for some files. +func KcovAvailable() bool { + return len(coverdata.Cover.Blocks) > 0 +} + var globalData struct { // files is the set of covered files sorted by filename. It is calculated at // startup. @@ -104,14 +115,14 @@ var coveragePool = sync.Pool{ // coverage tools, we reset the global coverage data every time this function is // run. func ConsumeCoverageData(w io.Writer) int { - once.Do(initCoverageData) + InitCoverageData() coverageMu.Lock() defer coverageMu.Unlock() total := 0 var pcBuffer [8]byte - for fileIndex, file := range globalData.files { + for fileNum, file := range globalData.files { counters := coverdata.Cover.Counters[file] for index := 0; index < len(counters); index++ { if atomic.LoadUint32(&counters[index]) == 0 { @@ -119,7 +130,7 @@ func ConsumeCoverageData(w io.Writer) int { } // Non-zero coverage data found; consume it and report as a PC. 
atomic.StoreUint32(&counters[index], 0) - pc := globalData.syntheticPCs[fileIndex][index] + pc := globalData.syntheticPCs[fileNum][index] usermem.ByteOrder.PutUint64(pcBuffer[:], pc) n, err := w.Write(pcBuffer[:]) if err != nil { @@ -142,31 +153,84 @@ func ConsumeCoverageData(w io.Writer) int { return total } -// initCoverageData initializes globalData. It should only be called once, -// before any kcov data is written. -func initCoverageData() { - // First, order all files. Then calculate synthetic PCs for every block - // (using the well-defined ordering for files as well). - for file := range coverdata.Cover.Blocks { - globalData.files = append(globalData.files, file) +// InitCoverageData initializes globalData. It should be called before any kcov +// data is written. +func InitCoverageData() { + once.Do(func() { + // First, order all files. Then calculate synthetic PCs for every block + // (using the well-defined ordering for files as well). + for file := range coverdata.Cover.Blocks { + globalData.files = append(globalData.files, file) + } + sort.Strings(globalData.files) + + for fileNum, file := range globalData.files { + blocks := coverdata.Cover.Blocks[file] + pcs := make([]uint64, 0, len(blocks)) + for blockNum := range blocks { + pcs = append(pcs, calculateSyntheticPC(fileNum, blockNum)) + } + globalData.syntheticPCs = append(globalData.syntheticPCs, pcs) + } + }) +} + +// Symbolize prints information about the block corresponding to pc. +func Symbolize(out io.Writer, pc uint64) error { + fileNum, blockNum := syntheticPCToIndexes(pc) + file, err := fileFromIndex(fileNum) + if err != nil { + return err + } + block, err := blockFromIndex(file, blockNum) + if err != nil { + return err } - sort.Strings(globalData.files) - - // nextSyntheticPC is the first PC that we generate for a block. - // - // This uses a standard-looking kernel range for simplicity. 
- // - // FIXME(b/160639712): This is only necessary because syzkaller requires - // addresses in the kernel range. If we can remove this constraint, then we - // should be able to use the actual addresses. - var nextSyntheticPC uint64 = 0xffffffff80000000 - for _, file := range globalData.files { - blocks := coverdata.Cover.Blocks[file] - thisFile := make([]uint64, 0, len(blocks)) - for range blocks { - thisFile = append(thisFile, nextSyntheticPC) - nextSyntheticPC++ // Advance. + writeBlock(out, pc, file, block) + return nil +} + +// WriteAllBlocks prints all information about all blocks along with their +// corresponding synthetic PCs. +func WriteAllBlocks(out io.Writer) { + for fileNum, file := range globalData.files { + for blockNum, block := range coverdata.Cover.Blocks[file] { + writeBlock(out, calculateSyntheticPC(fileNum, blockNum), file, block) } - globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile) } } + +func calculateSyntheticPC(fileNum int, blockNum int) uint64 { + return (uint64(fileNum) << blockBitLength) + uint64(blockNum) +} + +func syntheticPCToIndexes(pc uint64) (fileNum int, blockNum int) { + return int(pc >> blockBitLength), int(pc & ((1 << blockBitLength) - 1)) +} + +// fileFromIndex returns the name of the file in the sorted list of instrumented files. +func fileFromIndex(i int) (string, error) { + total := len(globalData.files) + if i < 0 || i >= total { + return "", fmt.Errorf("file index out of range: [%d] with length %d", i, total) + } + return globalData.files[i], nil +} + +// blockFromIndex returns the i-th block in the given file. 
+func blockFromIndex(file string, i int) (testing.CoverBlock, error) { + blocks, ok := coverdata.Cover.Blocks[file] + if !ok { + return testing.CoverBlock{}, fmt.Errorf("instrumented file %s does not exist", file) + } + total := len(blocks) + if i < 0 || i >= total { + return testing.CoverBlock{}, fmt.Errorf("block index out of range: [%d] with length %d", i, total) + } + return blocks[i], nil +} + +func writeBlock(out io.Writer, pc uint64, file string, block testing.CoverBlock) { + io.WriteString(out, fmt.Sprintf("%#x\n", pc)) + io.WriteString(out, fmt.Sprintf("%s:%d.%d,%d.%d\n", file, block.Line0, block.Col0, block.Line1, block.Col1)) +} diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go index f7f9dbf86..69eeb7528 100644 --- a/pkg/cpuid/cpuid.go +++ b/pkg/cpuid/cpuid.go @@ -36,3 +36,14 @@ package cpuid // On arm64, features are numbered according to the ELF HWCAP definition. // arch/arm64/include/uapi/asm/hwcap.h type Feature int + +// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a +// subset of the host feature set. +type ErrIncompatible struct { + message string +} + +// Error implements error. +func (e ErrIncompatible) Error() string { + return e.message +} diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go index 17a89c00d..392711e8f 100644 --- a/pkg/cpuid/cpuid_x86.go +++ b/pkg/cpuid/cpuid_x86.go @@ -681,17 +681,6 @@ func (fs *FeatureSet) Intel() bool { return fs.VendorID == intelVendorID } -// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a -// subset of the host feature set. -type ErrIncompatible struct { - message string -} - -// Error implements error. -func (e ErrIncompatible) Error() string { - return e.message -} - // CheckHostCompatible returns nil if fs is a subset of the host feature set. 
func (fs *FeatureSet) CheckHostCompatible() error { hfs := HostFeatureSet() diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD index aa8e4e1f3..cc31d0175 100644 --- a/pkg/flipcall/BUILD +++ b/pkg/flipcall/BUILD @@ -11,7 +11,8 @@ go_library( "futex_linux.go", "io.go", "packet_window_allocator.go", - "packet_window_mmap.go", + "packet_window_mmap_amd64.go", + "packet_window_mmap_arm64.go", ], visibility = ["//visibility:public"], deps = [ diff --git a/pkg/flipcall/packet_window_mmap.go b/pkg/flipcall/packet_window_mmap_amd64.go index 869183b11..869183b11 100644 --- a/pkg/flipcall/packet_window_mmap.go +++ b/pkg/flipcall/packet_window_mmap_amd64.go diff --git a/pkg/flipcall/packet_window_mmap_arm64.go b/pkg/flipcall/packet_window_mmap_arm64.go new file mode 100644 index 000000000..b9c9c44f6 --- /dev/null +++ b/pkg/flipcall/packet_window_mmap_arm64.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package flipcall + +import ( + "syscall" +) + +// Return a memory mapping of the pwd in memory that can be shared outside the sandbox. 
+func packetWindowMmap(pwd PacketWindowDescriptor) (uintptr, syscall.Errno) { + m, _, err := syscall.RawSyscall6(syscall.SYS_MMAP, 0, uintptr(pwd.Length), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED, uintptr(pwd.FD), uintptr(pwd.Offset)) + return m, err +} diff --git a/pkg/log/json.go b/pkg/log/json.go index bdf9d691e..8c52dcc87 100644 --- a/pkg/log/json.go +++ b/pkg/log/json.go @@ -27,8 +27,8 @@ type jsonLog struct { } // MarshalJSON implements json.Marshaler.MarashalJSON. -func (lv Level) MarshalJSON() ([]byte, error) { - switch lv { +func (l Level) MarshalJSON() ([]byte, error) { + switch l { case Warning: return []byte(`"warning"`), nil case Info: @@ -36,20 +36,20 @@ func (lv Level) MarshalJSON() ([]byte, error) { case Debug: return []byte(`"debug"`), nil default: - return nil, fmt.Errorf("unknown level %v", lv) + return nil, fmt.Errorf("unknown level %v", l) } } // UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON. It can unmarshal // from both string names and integers. -func (lv *Level) UnmarshalJSON(b []byte) error { +func (l *Level) UnmarshalJSON(b []byte) error { switch s := string(b); s { case "0", `"warning"`: - *lv = Warning + *l = Warning case "1", `"info"`: - *lv = Info + *l = Info case "2", `"debug"`: - *lv = Debug + *l = Debug default: return fmt.Errorf("unknown level %q", s) } diff --git a/pkg/log/log.go b/pkg/log/log.go index 37e0605ad..2e3408357 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -356,7 +356,7 @@ func CopyStandardLogTo(l Level) error { case Warning: f = Warningf default: - return fmt.Errorf("Unknown log level %v", l) + return fmt.Errorf("unknown log level %v", l) } stdlog.SetOutput(linewriter.NewWriter(func(p []byte) { diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 6acee90ef..aea7dde38 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -350,9 +350,13 @@ type VerifyParams struct { // For verifyMetadata, params.data is not needed. 
It only accesses params.tree // for the raw root hash. func verifyMetadata(params *VerifyParams, layout *Layout) error { - root := make([]byte, layout.digestSize) - if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { - return fmt.Errorf("failed to read root hash: %w", err) + var root []byte + // Only read the root hash if we expect that the Merkle tree file is non-empty. + if params.Size != 0 { + root = make([]byte, layout.digestSize) + if _, err := params.Tree.ReadAt(root, layout.blockOffset(layout.rootLevel(), 0 /* index */)); err != nil { + return fmt.Errorf("failed to read root hash: %w", err) + } } descriptor := VerityDescriptor{ Name: params.Name, diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e605b14c..2e3d427ae 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -678,16 +678,15 @@ func renameHelper(h *Harness, root p9.File, srcNames []string, dstNames []string // case. defer checkDeleted(h, dst) } else { + // If the type is different than the destination, then + // we expect the rename to fail. We expect that this + // is returned. + // + // If the file being renamed to itself, this is + // technically allowed and a no-op, but all the + // triggers will fire. if !selfRename { - // If the type is different than the - // destination, then we expect the rename to - // fail. We expect ensure that this is - // returned. expectedErr = syscall.EINVAL - } else { - // This is the file being renamed to itself. - // This is technically allowed and a no-op, but - // all the triggers will fire. 
} dst.Close() } diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index e7406b374..a29f06ddb 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -197,33 +197,33 @@ func BenchmarkSendRecv(b *testing.B) { for i := 0; i < b.N; i++ { tag, m, err := recv(server, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(1) { - b.Fatalf("got tag %v expected 1", tag) + b.Errorf("got tag %v expected 1", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %T expected *Rflush", m) + b.Errorf("got message %T expected *Rflush", m) } if err := send(server, Tag(2), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } } }() b.ResetTimer() for i := 0; i < b.N; i++ { if err := send(client, Tag(1), &Rflush{}); err != nil { - b.Fatalf("send got err %v expected nil", err) + b.Errorf("send got err %v expected nil", err) } tag, m, err := recv(client, maximumLength, msgRegistry.get) if err != nil { - b.Fatalf("recv got err %v expected nil", err) + b.Errorf("recv got err %v expected nil", err) } if tag != Tag(2) { - b.Fatalf("got tag %v expected 2", tag) + b.Errorf("got tag %v expected 2", tag) } if _, ok := m.(*Rflush); !ok { - b.Fatalf("got message %v expected *Rflush", m) + b.Errorf("got message %v expected *Rflush", m) } } } diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index a1b2e0cfe..54e825b28 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package pool provides a trivial integer pool. 
package pool import ( diff --git a/pkg/refsvfs2/BUILD b/pkg/refsvfs2/BUILD index bfa1daa10..0377c0876 100644 --- a/pkg/refsvfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -9,7 +9,7 @@ go_template( "refs_template.go", ], opt_consts = [ - "logTrace", + "enableLogging", ], types = [ "T", diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index f64b6c6ae..3fbc91aa5 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -74,11 +74,6 @@ func (r *Refs) LogRefs() bool { return enableLogging } -// EnableLeakCheck enables reference leak checking on r. -func (r *Refs) EnableLeakCheck() { - refsvfs2.Register(r) -} - // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. func (r *Refs) ReadRefs() int64 { @@ -136,7 +131,7 @@ func (r *Refs) TryIncRef() bool { func (r *Refs) DecRef(destroy func()) { v := atomic.AddInt64(&r.refCount, -1) if enableLogging { - refsvfs2.LogDecRef(r, v+1) + refsvfs2.LogDecRef(r, v) } switch { case v < 0: @@ -153,6 +148,6 @@ func (r *Refs) DecRef(destroy func()) { func (r *Refs) afterLoad() { if r.ReadRefs() > 0 { - r.EnableLeakCheck() + refsvfs2.Register(r) } } diff --git a/pkg/safemem/block_unsafe.go b/pkg/safemem/block_unsafe.go index e7fd30743..7857f5853 100644 --- a/pkg/safemem/block_unsafe.go +++ b/pkg/safemem/block_unsafe.go @@ -68,29 +68,29 @@ func blockFromSlice(slice []byte, needSafecopy bool) Block { } } -// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is +// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+length), which is // safe to access without safecopy. // -// Preconditions: ptr+len does not overflow. -func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, false) +// Preconditions: ptr+length does not overflow. 
+func BlockFromSafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, false) } // BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which // is not safe to access without safecopy. // // Preconditions: ptr+len does not overflow. -func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { - return blockFromPointer(ptr, len, true) +func BlockFromUnsafePointer(ptr unsafe.Pointer, length int) Block { + return blockFromPointer(ptr, length, true) } -func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { - if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { - panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) +func blockFromPointer(ptr unsafe.Pointer, length int, needSafecopy bool) Block { + if uptr := uintptr(ptr); uptr+uintptr(length) < uptr { + panic(fmt.Sprintf("ptr %#x + len %#x overflows", uptr, length)) } return Block{ start: ptr, - length: len, + length: length, needSafecopy: needSafecopy, } } diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 752e2dc32..ec17ebc4d 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -79,7 +79,7 @@ func Install(rules SyscallRules) error { // Perform the actual installation. if errno := SetFilter(instrs); errno != 0 { - return fmt.Errorf("Failed to set filter: %v", errno) + return fmt.Errorf("failed to set filter: %v", errno) } log.Infof("Seccomp filters installed.") diff --git a/pkg/segment/test/set_functions.go b/pkg/segment/test/set_functions.go index 7cd895cc7..652c010da 100644 --- a/pkg/segment/test/set_functions.go +++ b/pkg/segment/test/set_functions.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package segment is a test package. 
package segment type setFunctions struct{} diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index d75d665ae..dd2effdf9 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } + +// ErrFloatingPoint indicates a failed restore due to unusable floating point +// state. +type ErrFloatingPoint struct { + // supported is the supported floating point state. + supported uint64 + + // saved is the saved floating point state. + saved uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrFloatingPoint) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 19ce99d25..840e53d33 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -17,27 +17,10 @@ package arch import ( - "fmt" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" ) -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} - // XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 // and SSE state, so this is the equivalent XSTATE_BV value. 
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 5138f3bf5..35d2e07c3 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() { } } -// Pid returns the si_pid field. -func (s *SignalInfo) Pid() int32 { +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) } -// SetPid mutates the si_pid field. -func (s *SignalInfo) SetPid(val int32) { +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } -// Uid returns the si_uid field. -func (s *SignalInfo) Uid() int32 { +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) } -// SetUid mutates the si_uid field. -func (s *SignalInfo) SetUid(val int32) { +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2bf3c45e1..b78e29416 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -15,10 +15,10 @@ package control import ( - "errors" "runtime" "runtime/pprof" "runtime/trace" + "time" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -26,184 +26,253 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) -var errNoOutput = errors.New("no output writer provided") +// Profile includes profile-related RPC stubs. It provides a way to +// control the built-in runtime profiling facilities. +// +// The profile object must be instantied via NewProfile. +type Profile struct { + // kernel is the kernel under profile. It's immutable. + kernel *kernel.Kernel -// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call. 
-type ProfileOpts struct { - // File is the filesystem path for the profile. - File string `json:"path"` + // cpuMu protects CPU profiling. + cpuMu sync.Mutex - // FilePayload is the destination for the profiling output. - urpc.FilePayload + // blockMu protects block profiling. + blockMu sync.Mutex + + // mutexMu protects mutex profiling. + mutexMu sync.Mutex + + // traceMu protects trace profiling. + traceMu sync.Mutex + + // done is closed when profiling is done. + done chan struct{} } -// Profile includes profile-related RPC stubs. It provides a way to -// control the built-in pprof facility in sentry via sentryctl. -// -// The following options to sentryctl are added: +// NewProfile returns a new Profile object, and a stop callback. // -// - collect CPU profile on-demand. -// sentryctl -pid <pid> pprof-cpu-start -// sentryctl -pid <pid> pprof-cpu-stop -// -// - dump out the stack trace of current go routines. -// sentryctl -pid <pid> pprof-goroutine -type Profile struct { - // Kernel is the kernel under profile. It's immutable. - Kernel *kernel.Kernel +// The stop callback should be used at most once. +func NewProfile(k *kernel.Kernel) (*Profile, func()) { + p := &Profile{ + kernel: k, + done: make(chan struct{}), + } + return p, func() { + close(p.done) + } +} - // mu protects the fields below. - mu sync.Mutex +// CPUProfileOpts contains options specifically for CPU profiles. +type CPUProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload - // cpuFile is the current CPU profile output file. - cpuFile *fd.FD + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` - // traceFile is the current execution trace output file. - traceFile *fd.FD + // Hz is the rate, which may be zero. + Hz int `json:"hz"` } -// StartCPUProfile is an RPC stub which starts recording the CPU profile in a -// file. 
-func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error { +// CPU is an RPC stub which collects a CPU profile. +func (p *Profile) CPU(o *CPUProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } output, err := fd.NewFromFile(o.FilePayload.Files[0]) if err != nil { return err } + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.cpuMu.Lock() + defer p.cpuMu.Unlock() // Returns an error if profiling is already started. + if o.Hz != 0 { + runtime.SetCPUProfileRate(o.Hz) + } if err := pprof.StartCPUProfile(output); err != nil { - output.Close() return err } + defer pprof.StopCPUProfile() - p.cpuFile = output - return nil -} - -// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the -// profile data. It takes no argument. -func (p *Profile) StopCPUProfile(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.cpuFile == nil { - return errors.New("CPU profiling not started") + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: } - pprof.StopCPUProfile() - p.cpuFile.Close() - p.cpuFile = nil return nil } -// HeapProfile generates a heap profile for the sentry. -func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error { +// HeapProfileOpts contains options specifically for heap profiles. +type HeapProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload +} + +// Heap generates a heap profile. +func (p *Profile) Heap(o *HeapProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() + runtime.GC() // Get up-to-date statistics. - if err := pprof.WriteHeapProfile(output); err != nil { - return err - } - return nil + return pprof.WriteHeapProfile(output) +} + +// GoroutineProfileOpts contains options specifically for goroutine profiles. 
+type GoroutineProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload } -// GoroutineProfile is an RPC stub which dumps out the stack trace for all -// running goroutines. -func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error { +// Goroutine dumps out the stack trace for all running goroutines. +func (p *Profile) Goroutine(o *GoroutineProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil { - return err - } - return nil + + return pprof.Lookup("goroutine").WriteTo(output, 2) +} + +// BlockProfileOpts contains options specifically for block profiles. +type BlockProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Rate is the block profile rate. + Rate int `json:"rate"` } -// BlockProfile is an RPC stub which dumps out the stack trace that led to -// blocking on synchronization primitives. -func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error { +// Block dumps a blocking profile. +func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("block").WriteTo(output, 0); err != nil { - return err + + p.blockMu.Lock() + defer p.blockMu.Unlock() + + // Always set the rate. We then wait to collect a profile at this rate, + // and disable when we're done. + rate := 1 + if o.Rate != 0 { + rate = o.Rate } - return nil + runtime.SetBlockProfileRate(rate) + defer runtime.SetBlockProfileRate(0) + + // Collect the profile. 
+ select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("block").WriteTo(output, 0) +} + +// MutexProfileOpts contains options specifically for mutex profiles. +type MutexProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Fraction is the mutex profile fraction. + Fraction int `json:"fraction"` } -// MutexProfile is an RPC stub which dumps out the stack trace of holders of -// contended mutexes. -func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error { +// Mutex dumps a mutex profile. +func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil { - return err + + p.mutexMu.Lock() + defer p.mutexMu.Unlock() + + // Always set the fraction. + fraction := 1 + if o.Fraction != 0 { + fraction = o.Fraction } - return nil + runtime.SetMutexProfileFraction(fraction) + defer runtime.SetMutexProfileFraction(0) + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("mutex").WriteTo(output, 0) } -// StartTrace is an RPC stub which starts collection of an execution trace. -func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { +// TraceProfileOpts contains options specifically for traces. +type TraceProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` +} + +// Trace is an RPC stub which starts collection of an execution trace. +func (p *Profile) Trace(o *TraceProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. 
} output, err := fd.NewFromFile(o.FilePayload.Files[0]) if err != nil { return err } + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.traceMu.Lock() + defer p.traceMu.Unlock() // Returns an error if profiling is already started. if err := trace.Start(output); err != nil { output.Close() return err } + defer trace.Stop() // Ensure all trace contexts are registered. - p.Kernel.RebuildTraceContexts() - - p.traceFile = output - return nil -} - -// StopTrace is an RPC stub which stops collection of an ongoing execution -// trace and flushes the trace data. It takes no argument. -func (p *Profile) StopTrace(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() + p.kernel.RebuildTraceContexts() - if p.traceFile == nil { - return errors.New("Execution tracing not started") + // Wait for the trace. + select { + case <-time.After(o.Duration): + case <-p.done: } // Similarly to the case above, if tasks have not ended traces, we will // lose information. Thus we need to rebuild the tasks in order to have // complete information. This will not lose information if multiple // traces are overlapping. 
- p.Kernel.RebuildTraceContexts() + p.kernel.RebuildTraceContexts() - trace.Stop() - p.traceFile.Close() - p.traceFile = nil return nil } diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index d800f2c85..62eaca965 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { Callback: func(err error) { if err == nil { log.Infof("Save succeeded: exiting...") + s.Kernel.SetSaveSuccess(false /* autosave */) } else { log.Warningf("Save failed: exiting...") s.Kernel.SetSaveError(err) diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 314661475..badd5b073 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package fdimport provides the Import function. package fdimport import ( diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ff2fe6712..8e0aa9019 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // copyUpBuffers is a buffer pool for copying file content. The buffer // size is the same used by io.Copy. -var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }} +var copyUpBuffers = sync.Pool{ + New: func() interface{} { + b := make([]byte, 8*usermem.PageSize) + return &b + }, +} // copyContentsLocked copies the contents of lower to upper. It panics if // less than size bytes can be copied. @@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. - buf := copyUpBuffers.Get().([]byte) + buf := copyUpBuffers.Get().(*[]byte) defer copyUpBuffers.Put(buf) // Transfer the contents. 
@@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in // optimizations could be self-defeating. So we leave this as simple as possible. var offset int64 for { - nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset) + nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset) if err != nil && err != io.EOF { return err } @@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in } return nil } - nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset) + nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset) if err != nil { return err } diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index c7a11eec1..e04784db2 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) { wg.Add(1) go func(o *overlayTestFile) { if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) + t.Errorf("failed to copy up: %v", err) } wg.Done() }(file) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 8049538f2..ec3d3f96c 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File { // Read just fails the request. func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") + return 0, fmt.Errorf("TestFileOperations.Read not implemented") } // Write just fails the request. 
func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") + return 0, fmt.Errorf("TestFileOperations.Write not implemented") } diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index d481baf77..e5579095b 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType { return fs.BlockDevice case pattr.Mode.IsSocket(): return fs.Socket - case pattr.Mode.IsRegular(): - fallthrough default: return fs.RegularFile } diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 9d6fdd08f..e840b6f5e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { switch d.Inode.StableAttr.Type { case fs.Socket: + if i.session().overrides != nil { + return nil, syserror.ENXIO + } return i.getFileSocket(ctx, d, flags) case fs.Pipe: return i.getFilePipe(ctx, d, flags) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index fbfba1b58..2c14aa6d9 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport. // GetFile implements fs.InodeOperations.GetFile. 
func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + return newFile(ctx, d, flags, i), nil } diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index 29ff004f2..d0c565879 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { // GetFile implements fs.FileOperations.GetFile. func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil + return nil, syserror.ENXIO } // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index e04cd608d..ad4aea282 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. 
func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + if flags.Write { fsmetric.TmpfsOpensW.Increment() } else if flags.Read { diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 9009ba3c7..4a555bf72 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -200,7 +200,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt } var fd symlinkFD fd.LockFD.Init(&in.locks) - fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) + if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 3af807a21..204d8d143 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -129,6 +129,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EINVAL } fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + if fuseFDGeneric == nil { + return nil, nil, syserror.EINVAL + } defer fuseFDGeneric.DecRef(ctx) fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) if !ok { diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index dc0180812..41d679358 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) src = src[2:] } + _ = src // Remove unused warning. } // SizeBytes is the size of the payload of the FUSE_INIT response. 
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 435a21d77..36a3f6810 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -31,6 +31,7 @@ import ( fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag fileDescription: fileDescription{inode: i}, termios: linux.DefaultReplicaTermios, } + if task := kernel.TaskFromContext(ctx); task != nil { + fd.fgProcessGroup = task.ThreadGroup().ProcessGroup() + fd.session = fd.fgProcessGroup.Session() + } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 0ecb592cf..429733c10 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -164,11 +164,11 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // and write ends of a newly-created pipe, as for pipe(2) and pipe2(2). // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). 
-func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { fs := mnt.Filesystem().Impl().(*filesystem) inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, inode) defer d.DecRef(ctx) - return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) + return inode.pipe.ReaderWriterPair(ctx, mnt, d.VFSDentry(), flags) } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index a3780b222..75be6129f 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager { // MemoryManager's users count is incremented, and must be decremented by the // caller when it is no longer in use. func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { - if task.ExitState() == kernel.TaskExitDead { - return nil, syserror.ESRCH - } var m *mm.MemoryManager task.WithMuLocked(func(t *kernel.Task) { m = t.MemoryManager() @@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. 
+ return nil } defer m.DecUsers(ctx) @@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 } m, err := getMMIncRef(fd.inode.task) if err != nil { - return 0, nil + return 0, err } defer m.DecUsers(ctx) // Buffer the read data because of MM locks diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 10f1452ef..246bd87bc 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package signalfd provides basic signalfd file implementations. package signalfd import ( @@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 04e7110a3..a4ad625bb 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -163,7 +163,7 @@ afterSymlink: // verifyChildLocked verifies the hash of child against the already verified // hash of the parent to ensure the child is expected. verifyChild triggers a // sentry panic if unexpected modifications to the file system are detected. In -// noCrashOnVerificationFailure mode it returns a syserror instead. +// ErrorOnViolation mode it returns a syserror instead. // // Preconditions: // * fs.renameMu must be locked. 
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 5788c661f..a5171b5ad 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -64,6 +64,10 @@ const ( // tree file for "/foo" is "/.merkle.verity.foo". merklePrefix = ".merkle.verity." + // merkleRootPrefix is the prefix of the Merkle tree root file. This + // needs to be different from merklePrefix to avoid name collision. + merkleRootPrefix = ".merkleroot.verity." + // merkleOffsetInParentXattr is the extended attribute name specifying the // offset of the child hash in its parent's Merkle tree. merkleOffsetInParentXattr = "user.merkle.offset" @@ -88,10 +92,8 @@ const ( ) var ( - // noCrashOnVerificationFailure indicates whether the sandbox should panic - // whenever verification fails. If true, an error is returned instead of - // panicking. This should only be set for tests. - noCrashOnVerificationFailure bool + // action specifies the action towards detected violation. + action ViolationAction // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. @@ -102,6 +104,18 @@ var ( // content. type HashAlgorithm int +// ViolationAction is a type specifying the action when an integrity violation +// is detected. +type ViolationAction int + +const ( + // PanicOnViolation terminates the sentry on detected violation. + PanicOnViolation ViolationAction = 0 + // ErrorOnViolation returns an error from the violating system call on + // detected violation. + ErrorOnViolation = 1 +) + // Currently supported hashing algorithms include SHA256 and SHA512. const ( SHA256 HashAlgorithm = iota @@ -166,7 +180,7 @@ type filesystem struct { // its children. So they shouldn't be enabled the same time. This lock // is for the whole file system to ensure that no more than one file is // enabled the same time. 
- verityMu sync.RWMutex + verityMu sync.RWMutex `state:"nosave"` } // InternalFilesystemOptions may be passed as @@ -196,10 +210,8 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // NoCrashOnVerificationFailure indicates whether the sandbox should - // panic whenever verification fails. If true, an error is returned - // instead of panicking. This should only be set for tests. - NoCrashOnVerificationFailure bool + // Action specifies the action on an integrity violation. + Action ViolationAction } // Name implements vfs.FilesystemType.Name. @@ -211,10 +223,10 @@ func (FilesystemType) Name() string { func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means -// unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +// unexpected modification to the file system is detected. In ErrorOnViolation +// mode, it returns EIO, otherwise it panic. func alertIntegrityViolation(msg string) error { - if noCrashOnVerificationFailure { + if action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -227,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + action = iopts.Action // Mount the lower file system. The lower file system is wrapped inside // verity, and should not be exposed or connected. 
@@ -255,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merklePrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -744,7 +756,7 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) // file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The // hash of the generated Merkle tree and the data size is returned. If fd // points to a regular file, the data is the content of the file. If fd points -// to a directory, the data is all hahes of its children, written to the Merkle +// to a directory, the data is all hashes of its children, written to the Merkle // tree file. // // Preconditions: fd.d.fs.verityMu must be locked. diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 6ced0afc9..30d8b4355 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -92,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - NoCrashOnVerificationFailure: true, + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + Alg: hashAlg, + AllowRuntimeEnable: true, + Action: ErrorOnViolation, }, }, }) @@ -239,6 +239,18 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, return fd, dataSize, err } +// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. 
+func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { + // Create the file in the underlying file system. + _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + if err != nil { + return nil, err + } + // Now open the verity file descriptor. + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) + return fd, err +} + // flipRandomBit randomly flips a bit in the file represented by fd. func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) @@ -349,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-empty-file" + fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newEmptyFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + var buf []byte + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != 0 { + t.Errorf("fd.Read got read length %d, expected 0", n) + } + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. 
func TestReopenUnmodifiedFileSucceeds(t *testing.T) { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2cdcdfc1f..b8627a54f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -214,9 +214,11 @@ type Kernel struct { // netlinkPorts manages allocation of netlink socket port IDs. netlinkPorts *port.Manager - // saveErr is the error causing the sandbox to exit during save, if - // any. It is protected by extMu. - saveErr error `state:"nosave"` + // saveStatus is nil if the sandbox has not been saved, errSaved or + // errAutoSaved if it has been saved successfully, or the error causing the + // sandbox to exit during save. + // It is protected by extMu. + saveStatus error `state:"nosave"` // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` @@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager { return k.netlinkPorts } -// SaveError returns the sandbox error that caused the kernel to exit during -// save. -func (k *Kernel) SaveError() error { +var ( + errSaved = errors.New("sandbox has been successfully saved") + errAutoSaved = errors.New("sandbox has been successfully auto-saved") +) + +// SaveStatus returns the sandbox save status. If it was saved successfully, +// autosaved indicates whether save was triggered by autosave. If it was not +// saved successfully, err indicates the sandbox error that caused the kernel to +// exit during save. +func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) { + k.extMu.Lock() + defer k.extMu.Unlock() + switch k.saveStatus { + case nil: + return false, false, nil + case errSaved: + return true, false, nil + case errAutoSaved: + return true, true, nil + default: + return false, false, k.saveStatus + } +} + +// SetSaveSuccess sets the flag indicating that save completed successfully, if +// no status was already set. 
+func (k *Kernel) SetSaveSuccess(autosave bool) { k.extMu.Lock() defer k.extMu.Unlock() - return k.saveErr + if k.saveStatus == nil { + if autosave { + k.saveStatus = errAutoSaved + } else { + k.saveStatus = errSaved + } + } } // SetSaveError sets the sandbox error that caused the kernel to exit during @@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error { func (k *Kernel) SetSaveError(err error) { k.extMu.Lock() defer k.extMu.Unlock() - if k.saveErr == nil { - k.saveErr = err + if k.saveStatus == nil { + k.saveStatus = err } } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 7b23cbe86..2d47d2e82 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -63,10 +63,19 @@ func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { // ReaderWriterPair returns read-only and write-only FDs for vp. // // Preconditions: statusFlags should not contain an open access mode. -func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func (vp *VFSPipe) ReaderWriterPair(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { // Connected pipes share the same locks. locks := &vfs.FileLocks{} - return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + r, err := vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks) + if err != nil { + return nil, nil, err + } + w, err := vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + if err != nil { + r.DecRef(ctx) + return nil, nil, err + } + return r, w, nil } // Allocate implements vfs.FileDescriptionImpl.Allocate. 
@@ -85,7 +94,10 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s return nil, syserror.EINVAL } - fd := vp.newFD(mnt, vfsd, statusFlags, locks) + fd, err := vp.newFD(mnt, vfsd, statusFlags, locks) + if err != nil { + return nil, err + } // Named pipes have special blocking semantics during open: // @@ -137,16 +149,18 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { fd := &VFSPipeFD{ pipe: &vp.pipe, } fd.LockFD.Init(locks) - fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ + if err := fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, UseDentryMetadata: true, - }) + }); err != nil { + return nil, err + } switch { case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable(): @@ -160,7 +174,7 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l panic("invalid pipe flags: must be readable, writable, or both") } - return &fd.vfsfd + return &fd.vfsfd, nil } // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. 
It also implements diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1abfe2201..cef58a590 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) { Signo: int32(linux.SIGTRAP), Code: code, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index 31198d772..db01e4a97 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -52,6 +52,9 @@ type Registry struct { mu sync.Mutex `state:"nosave"` semaphores map[int32]*Set lastIDUsed int32 + // indexes maintains a mapping between a set's index in virtual array and + // its identifier. + indexes map[int32]int32 } // Set represents a set of semaphores that can be operated atomically. @@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, semaphores: make(map[int32]*Set), + indexes: make(map[int32]int32), } } @@ -163,6 +167,9 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Apply system limits. + // + // Map semaphores and map indexes in a registry are of the same size, + // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } @@ -186,12 +193,43 @@ func (r *Registry) IPCInfo() *linux.SemInfo { SemMsl: linux.SEMMSL, SemOpm: linux.SEMOPM, SemUme: linux.SEMUME, - SemUsz: 0, // SemUsz not supported. 
+ SemUsz: linux.SEMUSZ, SemVmx: linux.SEMVMX, SemAem: linux.SEMAEM, } } +// SemInfo returns a seminfo structure containing the same information as +// for IPC_INFO, except that SemUsz field returns the number of existing +// semaphore sets, and SemAem field returns the number of existing semaphores. +func (r *Registry) SemInfo() *linux.SemInfo { + r.mu.Lock() + defer r.mu.Unlock() + + info := r.IPCInfo() + info.SemUsz = uint32(len(r.semaphores)) + info.SemAem = uint32(r.totalSems()) + + return info +} + +// HighestIndex returns the index of the highest used entry in +// the kernel's array. +func (r *Registry) HighestIndex() int32 { + r.mu.Lock() + defer r.mu.Unlock() + + // By default, highest used index is 0 even though + // there is no semaphroe set. + var highestIndex int32 + for index := range r.indexes { + if index > highestIndex { + highestIndex = index + } + } + return highestIndex +} + // RemoveID removes set with give 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { @@ -202,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { if set == nil { return syserror.EINVAL } + index, found := r.findIndexByID(id) + if !found { + // Inconsistent state. 
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id)) + } set.mu.Lock() defer set.mu.Unlock() @@ -213,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { } delete(r.semaphores, set.ID) + delete(r.indexes, index) set.destroy() return nil } @@ -236,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File continue } if r.semaphores[id] == nil { + index, found := r.findFirstAvailableIndex() + if !found { + panic("unable to find an available index") + } + r.indexes[index] = id r.lastIDUsed = id r.semaphores[id] = set set.ID = id @@ -254,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set { return r.semaphores[id] } +// FindByIndex looks up a set given an index. +func (r *Registry) FindByIndex(index int32) *Set { + r.mu.Lock() + defer r.mu.Unlock() + + id, present := r.indexes[index] + if !present { + return nil + } + return r.semaphores[id] +} + func (r *Registry) findByKey(key int32) *Set { for _, v := range r.semaphores { if v.key == key { @@ -263,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set { return nil } +func (r *Registry) findIndexByID(id int32) (int32, bool) { + for k, v := range r.indexes { + if v == id { + return k, true + } + } + return 0, false +} + +func (r *Registry) findFirstAvailableIndex() (int32, bool) { + for index := int32(0); index < setsMax; index++ { + if _, present := r.indexes[index]; !present { + return index, true + } + } + return 0, false +} + func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 80a592c8f..073e14507 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -6,6 +6,9 @@ package(licenses = ["notice"]) go_template_instance( name = "shm_refs", out = "shm_refs.go", + consts = { + "enableLogging": "true", + }, package = "shm", prefix = "Shm", template = "//pkg/refsvfs2:refs_template", diff --git 
a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index e8cce37d0..2488ae7d5 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 78f718cfe..884966120 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c5137c282..16986244c 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -368,8 +368,8 @@ func (t *Task) exitChildren() { Signo: int32(sig), Code: arch.SignalInfoUser, } - siginfo.SetPid(int32(c.tg.pidns.tids[t])) - siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) + siginfo.SetPID(int32(c.tg.pidns.tids[t])) + siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) c.tg.signalHandlers.mu.Lock() c.sendSignalLocked(siginfo, true /* group */) c.tg.signalHandlers.mu.Unlock() @@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si info := &arch.SignalInfo{ Signo: 
int32(sig), } - info.SetPid(int32(receiver.tg.pidns.tids[t])) - info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.tids[t])) + info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { info.Code = arch.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 42dd3e278..75af3af79 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { Signo: int32(linux.SIGCHLD), Code: code, } - sigchld.SetPid(int32(t.tg.pidns.tids[target])) - sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + sigchld.SetPID(int32(t.tg.pidns.tids[target])) + sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) sigchld.SetStatus(status) // TODO(b/72102453): Set utime, stime. t.sendSignalLocked(sigchld, true /* group */) @@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { Signo: int32(sig), Code: t.ptraceCode, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } else { t.ptraceCode = int32(sig) t.ptraceSiginfo = nil @@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if parent == nil { // Tracer has detached and t was created by Kernel.CreateProcess(). // Pretend the parent is in an ancestor PID + user namespace. 
- info.SetPid(0) - info.SetUid(int32(auth.OverflowUID)) + info.SetPID(0) + info.SetUID(int32(auth.OverflowUID)) } else { - info.SetPid(int32(t.tg.pidns.tids[parent])) - info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + info.SetPID(int32(t.tg.pidns.tids[parent])) + info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } } t.tg.signalHandlers.mu.Lock() diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 7fd77925f..49e21026e 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp // Translations must be contiguous and in increasing order of // Translation.Source. if i > 0 && ts[i-1].Source.End != t.Source.Start { - return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t) + return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t) } // At least part of each Translation must be required. if t.Source.Intersect(required).Length() == 0 { diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4c8cd38ed..5ab2ef79f 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -36,12 +36,12 @@ type aioManager struct { contexts map[uint64]*AIOContext } -func (a *aioManager) destroy() { - a.mu.Lock() - defer a.mu.Unlock() +func (mm *MemoryManager) destroyAIOManager(ctx context.Context) { + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() - for _, ctx := range a.contexts { - ctx.destroy() + for id := range mm.aioManager.contexts { + mm.destroyAIOContextLocked(ctx, id) } } @@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { // be drained. // // Nil is returned if the context does not exist. 
-func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { - a.mu.Lock() - defer a.mu.Unlock() - ctx, ok := a.contexts[id] +// +// Precondition: mm.aioManager.mu is locked. +func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext { + aioCtx, ok := mm.aioManager.contexts[id] if !ok { return nil } - delete(a.contexts, id) - ctx.destroy() - return ctx + + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + + delete(mm.aioManager.contexts, id) + aioCtx.destroy() + return aioCtx } // lookupAIOContext looks up the given context. @@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() { } } -// Prepare reserves space for a new request, returning true if available. -// Returns false if the context is busy. -func (ctx *AIOContext) Prepare() bool { +// Prepare reserves space for a new request, returning nil if available. +// Returns EAGAIN if the context is busy and EINVAL if the context is dead. +func (ctx *AIOContext) Prepare() error { ctx.mu.Lock() defer ctx.mu.Unlock() + if ctx.dead { + // Context died after the caller looked it up. + return syserror.EINVAL + } if ctx.outstanding >= ctx.maxOutstanding { - return false + // Context is busy. + return syserror.EAGAIN } ctx.outstanding++ - return true + return nil } // PopRequest pops a completed request if available, this function does not do @@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // DestroyAIOContext destroys an asynchronous I/O context. It returns the // destroyed context. nil if the context does not exist. 
func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { - if _, ok := mm.LookupAIOContext(ctx, id); !ok { + if !mm.isValidAddr(ctx, id) { return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. - // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) - - return mm.aioManager.destroyAIOContext(id) + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() + return mm.destroyAIOContextLocked(ctx, id) } // LookupAIOContext looks up the given context. It returns false if the context @@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC return nil, false } - // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes - // from id). - var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) - if err != nil { + // Protect against 'id' that is inaccessible. + if !mm.isValidAddr(ctx, id) { return nil, false } return aioCtx, true } + +// isValidAddr determines if the address `id` is valid. (Linux also reads 4 +// bytes from id). +func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { + var buf [4]byte + _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + return err == nil +} diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index 3dabac1af..e8931922f 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -15,6 +15,6 @@ package mm // afterLoad is invoked by stateify. 
-func (a *AIOContext) afterLoad() { - a.requestReady = make(chan struct{}, 1) +func (ctx *AIOContext) afterLoad() { + ctx.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 09dbc06a4..120707429 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users)) } - mm.aioManager.destroy() + mm.destroyAIOManager(ctx) mm.metadataMu.Lock() exe := mm.executable diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index acac3d357..bc53bd41e 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } } + +// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be +// prepared after destruction. +func TestAIOPrepareAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + defer mm.DecUsers(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + aioCtx, ok := mm.LookupAIOContext(ctx, id) + if !ok { + t.Fatalf("AIOContext not found") + } + mm.DestroyAIOContext(ctx, id) + + // Prepare should fail because aioCtx should be destroyed. + if err := aioCtx.Prepare(); err != syserror.EINVAL { + t.Errorf("aioCtx.Prepare got err %v want nil", err) + } else if err == nil { + aioCtx.CancelPendingRequest() + } +} + +// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be +// looked up after memory manager is destroyed. +func TestAIOLookupAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + mm.DecUsers(ctx) + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + mm.DecUsers(ctx) // This destroys the AIOContext manager. 
+ + if _, ok := mm.LookupAIOContext(ctx, id); ok { + t.Errorf("AIOContext found even after AIOContext manager is destroyed") + } +} diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 812ab80ef..aacd7ce70 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // facilitate vsyscall emulation. See patchSignalInfo. patchSignalInfo(regs, &c.signalInfo) return false - } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) { // The signal was generated by this process. That means // that it was an interrupt or something else that we // should bail for. Note that we ignore signals diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 5d01d21dd..2852b7387 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "arch_genrule", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -39,19 +39,19 @@ go_template_instance( template = ":defs_arm64", ) -genrule( +arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) -genrule( +arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + 
cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 9742308d8..a9703baf6 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,6 +24,9 @@ go_binary( "defs_impl_arm64.go", "main.go", ], + # Use the libc malloc to avoid any extra dependencies. This is required to + # pass the sentry deps test. + system_malloc = True, visibility = [ "//pkg/sentry/platform/kvm:__pkg__", "//pkg/sentry/platform/ring0:__pkg__", diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 90a7b8392..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,11 +53,17 @@ func IsCanonical(addr uint64) bool { return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 } +// SwitchToUser performs an eret. +// +// The return value is the exception vector. +// +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) if switchOpts.Flush { - FlushTlbAll() + FlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index ef0d8974d..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -22,19 +22,25 @@ func storeAppASID(asid uintptr) // LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. func LocalFlushTlbAll() -// FlushTlbAll flush all tlb. +// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable. +func FlushTlbByVA(addr uintptr) + +// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. 
+func FlushTlbByASID(asid uintptr) + +// FlushTlbAll invalidates all tlb. func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) -// FPCR returns the value of FPCR register. +// GetFPCR returns the value of FPCR register. func GetFPCR() (value uintptr) // SetFPCR writes the FPCR value. func SetFPCR(value uintptr) -// FPSR returns the value of FPSR register. +// GetFPSR returns the value of FPSR register. func GetFPSR() (value uintptr) // SetFPSR writes the FPSR value. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 6f4923539..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,23 @@ #include "funcdata.h" #include "textflag.h" +#define TLBI_ASID_SHIFT 48 + +TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + DSB $10 // dsb(ishst) + WORD $0xd50883a1 // tlbi vale1is, x1 + DSB $11 // dsb(ish) + RET + +TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088341 // tlbi aside1is, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..fb7c5dc61 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -23,7 +23,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..ff6b71802 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" 
"gvisor.dev/gvisor/pkg/usermem" ) @@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, + ) +} + +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + +// PackSockExtendedErr packs an IP*_RECVERR socket control message. 
+func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { + return putCmsgStruct( + buf, + sockErr.CMsgLevel(), + sockErr.CMsgType(), + t.Arch().Width(), + sockErr, ) } @@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) + } + + if cmsgs.IP.SockErr != nil { + buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) } return buf @@ -416,17 +447,19 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + if cmsgs.IP.SockErr != nil { + space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) + } + + return space } // Parse parses a raw socket control message into portable objects. 
@@ -489,6 +522,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTimestamp = true + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp) + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +553,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + + case linux.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg i += binary.AlignUp(length, width) default: @@ -528,6 +588,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += 
binary.AlignUp(length, width) + + case linux.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..5b868216d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, 
linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } -// RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Only allow known and safe flags. - // - // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the - // Socket interface's dependence on netstack. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument - } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT - var senderAddr linux.SockAddr + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } var senderAddrBuf []byte if senderRequested { senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) } - var controlBuf []byte - var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. 
- if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT +// RecvMsg implements socket.Socket.RecvMsg. +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + // Only allow known and safe flags. + if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument + } - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + var senderAddrBuf []byte + var controlBuf []byte + var msgFlags int + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. 
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + if dsts.IsEmpty() { + return 0, nil + } + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) - if flags&syscall.MSG_DONTWAIT == 0 { + n, err := copyToDst() + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. 
@@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: + switch unixCmsg.Header.Type { + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + 
binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 3cc0d4f0f..3f587638f 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -320,7 +320,7 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. 
- readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -408,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = v - s.readCM = cms + s.readCM = socket.NewIPControlMessages(s.family, cms) atomic.StoreUint32(&s.readViewHasData, 1) return nil @@ -428,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -965,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1046,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &v, nil case linux.SO_BINDTODEVICE: - var v tcpip.BindToDeviceOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetBindToDevice() if v == 0 { var b primitive.ByteSlice return &b, nil @@ -1092,11 +1085,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1127,13 +1117,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, 
syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1417,6 +1402,21 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) return &v, nil + case linux.IPV6_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1583,6 +1583,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) return &v, nil + case linux.IP_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) return &v, nil + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil + case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { return nil, syserr.ErrInvalidArgument @@ -1785,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, 
nam } name := string(optVal[:n]) if name == "" { - v := tcpip.BindToDeviceOption(0) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0)) } s := t.NetworkContext() if s == nil { @@ -1794,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - v := tcpip.BindToDeviceOption(nicID) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID)) } } return syserr.ErrUnknownDevice @@ -1864,8 +1878,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1888,10 +1902,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. 
@@ -2094,6 +2109,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2115,6 +2139,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name ep.SocketOptions().SetReceiveTClass(v != 0) return nil + case linux.IPV6_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2303,6 +2337,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetReceiveTOS(v != 0) return nil + case linux.IP_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil + case linux.IP_PKTINFO: if len(optVal) == 0 { return nil @@ -2325,6 +2370,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in ep.SocketOptions().SetHeaderIncluded(v != 0) return nil + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { return syserr.ErrInvalidArgument @@ -2360,10 +2417,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - 
linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2437,11 +2492,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_MULTICAST_IF, linux.IPV6_MULTICAST_LOOP, linux.IPV6_RECVDSTOPTS, - linux.IPV6_RECVERR, linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2472,7 +2525,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { linux.IP_PKTINFO, linux.IP_PKTOPTIONS, linux.IP_MTU_DISCOVER, - linux.IP_RECVERR, linux.IP_RECVTTL, linux.IP_RECVTOS, linux.IP_MTU, @@ -2701,7 +2753,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // We need to peek beyond the first message. dst = dst.DropFirst(n) num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) + n, err := s.Endpoint.Peek(dsts) // TODO(b/78348848): Handle peek timestamp. 
if err != nil { return int64(n), syserr.TranslateNetstackError(err).ToError() @@ -2745,15 +2797,19 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasInq: s.readCM.HasInq, + Inq: s.readCM.Inq, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, + SockErr: s.readCM.SockErr, }, } } @@ -2770,9 +2826,66 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb(). +func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { + so := s.Endpoint.SocketOptions() + err := so.DequeueErr() + if err == nil { + return nil + } + + // Update socket error to reflect ICMP errors in queue. + if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + so.SetLastError(nextErr.Err) + } else if err.ErrOrigin.IsICMPErr() { + so.SetLastError(nil) + } + return err +} + +// addrFamilyFromNetProto returns the address family identifier for the given +// network protocol. 
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { + switch net { + case header.IPv4ProtocolNumber: + return linux.AF_INET + case header.IPv6ProtocolNumber: + return linux.AF_INET6 + default: + panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net)) + } +} + +// recvErr handles MSG_ERRQUEUE for recvmsg(2). +// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). +func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + sockErr := s.dequeueErr() + if sockErr == nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + } + + // The payload of the original packet that caused the error is passed as + // normal data via msg_iovec. -- recvmsg(2) + msgFlags := linux.MSG_ERRQUEUE + if int(dst.NumBytes()) < len(sockErr.Payload) { + msgFlags |= linux.MSG_TRUNC + } + n, err := dst.CopyOut(t, sockErr.Payload) + + // The original destination address of the datagram that caused the error is + // supplied via msg_name. -- recvmsg(2) + dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst) + cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})} + return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err) +} + // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. 
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { + if flags&linux.MSG_ERRQUEUE != 0 { + return s.recvErr(t, dst) + } + trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9049e8a21..97729dacc 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -44,7 +44,134 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// errOriginToLinux maps tcpip socket origin to Linux socket origin constants. +func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 { + switch origin { + case tcpip.SockExtErrorOriginNone: + return linux.SO_EE_ORIGIN_NONE + case tcpip.SockExtErrorOriginLocal: + return linux.SO_EE_ORIGIN_LOCAL + case tcpip.SockExtErrorOriginICMP: + return linux.SO_EE_ORIGIN_ICMP + case tcpip.SockExtErrorOriginICMP6: + return linux.SO_EE_ORIGIN_ICMP6 + default: + panic(fmt.Sprintf("unknown socket origin: %d", origin)) + } +} + +// sockErrCmsgToLinux converts SockError control message from tcpip format to +// Linux format. 
+func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { + if sockErr == nil { + return nil + } + + ee := linux.SockExtendedErr{ + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Origin: errOriginToLinux(sockErr.ErrOrigin), + Type: sockErr.ErrType, + Code: sockErr.ErrCode, + Info: sockErr.ErrInfo, + } + + switch sockErr.NetProto { + case header.IPv4ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet) + } + return errMsg + case header.IPv6ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet6) + } + return errMsg + default: + panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto)) + } +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. +func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + SockErr: sockErrCmsgToLinux(cmgs.SockErr), + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. 
+// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr linux.SockErrCMsg } // Release releases Unix domain socket credentials and rights. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0247e93fa..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -746,9 +746,6 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -840,12 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. 
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } return nil } @@ -922,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index cff442846..b815e498f 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: 
syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 0bf313a13..c2285f796 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := ctx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := ctx.Prepare(); err != nil { + return err } if eventFile != nil { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 8db587401..c33571f43 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } + file, err := d.Inode.GetFile(t, d, fileFlags) + if err != nil { + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + } + defer file.DecRef(t) + // Truncate is called when O_TRUNC is specified for any kind of // existing Dirent. Behavior is delegated to the entry's Truncate // implementation. @@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } - file, err := d.Inode.GetFile(t, d, fileFlags) - if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) - } - defer file.DecRef(t) - // Success. 
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index a1601676f..1166cd7bb 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -150,14 +150,33 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal buf := args[3].Pointer() r := t.IPCNamespace().SemaphoreRegistry() info := r.IPCInfo() - _, err := info.CopyOut(t, buf) - // TODO(gvisor.dev/issue/137): Return the index of the highest used entry. - return 0, nil, err + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.SemInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil - case linux.SEM_INFO, - linux.SEM_STAT, - linux.SEM_STAT_ANY: + case linux.SEM_STAT: + arg := args[3].Pointer() + // id is an index in SEM_STAT. 
+ semid, ds, err := semStat(t, id) + if err != nil { + return 0, nil, err + } + if _, err := ds.CopyOut(t, arg); err != nil { + return 0, nil, err + } + return uintptr(semid), nil, err + case linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -202,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { return set.GetStat(creds) } +func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByIndex(index) + if set == nil { + return 0, nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + ds, err := set.GetStat(creds) + return set.ID, ds, err +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index e748d33d8..d639c9bf7 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(target.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) + info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) if err := target.SendGroupSignal(info); err != syserror.ESRCH { return 0, nil, err } @@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := 
tg.SendSignal(info) if err == syserror.ESRCH { // ESRCH is ignored because it means the task @@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. if err := tg.SendSignal(info); err != syserror.ESRCH { lastErr = err @@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI Signo: int32(sig), Code: arch.SignalInfoTkill, } - info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..4adfa6637 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. 
if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 983f8d396..8e7ac0ffe 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal si := arch.SignalInfo{ Signo: int32(linux.SIGCHLD), } - si.SetPid(int32(wr.TID)) - si.SetUid(int32(wr.UID)) + si.SetPID(int32(wr.TID)) + si.SetUID(int32(wr.UID)) // TODO(b/73541790): convert kernel.ExitStatus to functions and make // WaitResult.Status a linux.WaitStatus. s := syscall.WaitStatus(wr.Status) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 6d0a38330..1365a5a62 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := aioCtx.Prepare(); !ready { - // Context is busy. 
- return syserror.EAGAIN + if err := aioCtx.Prepare(); err != nil { + return err } if eventFD != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index ee38fdca0..6986e39fe 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -42,7 +42,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return syserror.EINVAL } - r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + if err != nil { + return err + } defer r.DecRef(t) defer w.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..987012acc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. 
- package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. 
func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 8f070ed53..8998a82dd 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -101,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount { return mounts } +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // loadMounts is called by stateify. 
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { if mounts == nil { @@ -112,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { } } +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + func (mnt *Mount) afterLoad() { if atomic.LoadInt64(&mnt.refs) != 0 { refsvfs2.Register(mnt) diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go index d462c3eef..e8315326d 100644 --- a/pkg/shim/v1/proc/process.go +++ b/pkg/shim/v1/proc/process.go @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package proc contains process-related utilities. package proc import ( diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD index 05c595bc9..e5b6bf186 100644 --- a/pkg/shim/v1/shim/BUILD +++ b/pkg/shim/v1/shim/BUILD @@ -8,6 +8,7 @@ go_library( "api.go", "platform.go", "service.go", + "shim.go", ], visibility = [ "//pkg/shim:__subpackages__", diff --git a/pkg/shim/v1/shim/shim.go b/pkg/shim/v1/shim/shim.go new file mode 100644 index 000000000..1855a8769 --- /dev/null +++ b/pkg/shim/v1/shim/shim.go @@ -0,0 +1,17 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package shim contains the core containerd shim implementation. 
+package shim diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go index 07e346654..21e75d16d 100644 --- a/pkg/shim/v1/utils/utils.go +++ b/pkg/shim/v1/utils/utils.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package utils contains utility functions. package utils import ( diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD index f37fefddc..b0e8daa51 100644 --- a/pkg/shim/v2/BUILD +++ b/pkg/shim/v2/BUILD @@ -22,6 +22,7 @@ go_library( "//runsc/specutils", "@com_github_burntsushi_toml//:go_default_library", "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_cgroups//stats/v1:go_default_library", "@com_github_containerd_console//:go_default_library", "@com_github_containerd_containerd//api/events:go_default_library", "@com_github_containerd_containerd//api/types/task:go_default_library", diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go index cba403cae..6aaf5fab8 100644 --- a/pkg/shim/v2/service.go +++ b/pkg/shim/v2/service.go @@ -28,6 +28,7 @@ import ( "github.com/BurntSushi/toml" "github.com/containerd/cgroups" + cgroupsstats "github.com/containerd/cgroups/stats/v1" "github.com/containerd/console" "github.com/containerd/containerd/api/events" "github.com/containerd/containerd/api/types/task" @@ -735,48 +736,48 @@ func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI. // as runc. 
// // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 - metrics := &cgroups.Metrics{ - CPU: &cgroups.CPUStat{ - Usage: &cgroups.CPUUsage{ + metrics := &cgroupsstats.Metrics{ + CPU: &cgroupsstats.CPUStat{ + Usage: &cgroupsstats.CPUUsage{ Total: stats.Cpu.Usage.Total, Kernel: stats.Cpu.Usage.Kernel, User: stats.Cpu.Usage.User, PerCPU: stats.Cpu.Usage.Percpu, }, - Throttling: &cgroups.Throttle{ + Throttling: &cgroupsstats.Throttle{ Periods: stats.Cpu.Throttling.Periods, ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, ThrottledTime: stats.Cpu.Throttling.ThrottledTime, }, }, - Memory: &cgroups.MemoryStat{ + Memory: &cgroupsstats.MemoryStat{ Cache: stats.Memory.Cache, - Usage: &cgroups.MemoryEntry{ + Usage: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Usage.Limit, Usage: stats.Memory.Usage.Usage, Max: stats.Memory.Usage.Max, Failcnt: stats.Memory.Usage.Failcnt, }, - Swap: &cgroups.MemoryEntry{ + Swap: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Swap.Limit, Usage: stats.Memory.Swap.Usage, Max: stats.Memory.Swap.Max, Failcnt: stats.Memory.Swap.Failcnt, }, - Kernel: &cgroups.MemoryEntry{ + Kernel: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.Kernel.Limit, Usage: stats.Memory.Kernel.Usage, Max: stats.Memory.Kernel.Max, Failcnt: stats.Memory.Kernel.Failcnt, }, - KernelTCP: &cgroups.MemoryEntry{ + KernelTCP: &cgroupsstats.MemoryEntry{ Limit: stats.Memory.KernelTCP.Limit, Usage: stats.Memory.KernelTCP.Usage, Max: stats.Memory.KernelTCP.Max, Failcnt: stats.Memory.KernelTCP.Failcnt, }, }, - Pids: &cgroups.PidsStat{ + Pids: &cgroupsstats.PidsStat{ Current: stats.Pids.Current, Limit: stats.Pids.Limit, }, diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go index d3931c952..2b1609af0 100644 --- a/pkg/state/tests/integer_test.go +++ b/pkg/state/tests/integer_test.go @@ -20,21 +20,21 @@ import ( ) var ( - allIntTs = []int{-1, 0, 1} - allInt8s = []int8{math.MinInt8, -1, 0, 1, 
math.MaxInt8} - allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} - allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} - allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} - allUintTs = []uint{0, 1} - allUintptrs = []uintptr{0, 1, ^uintptr(0)} - allUint8s = []uint8{0, 1, math.MaxUint8} - allUint16s = []uint16{0, 1, math.MaxUint16} - allUint32s = []uint32{0, 1, math.MaxUint32} - allUint64s = []uint64{0, 1, math.MaxUint64} + allBasicInts = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allBasicUints = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} ) var allInts = flatten( - allIntTs, + allBasicInts, allInt8s, allInt16s, allInt32s, @@ -42,7 +42,7 @@ var allInts = flatten( ) var allUints = flatten( - allUintTs, + allBasicUints, allUintptrs, allUint8s, allUint16s, diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index be5bc99fc..28e62abbb 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -10,15 +10,34 @@ exports_files(["LICENSE"]) go_template( name = "generic_atomicptr", - srcs = ["atomicptr_unsafe.go"], + srcs = ["generic_atomicptr_unsafe.go"], types = [ "Value", ], ) go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + +go_template( name = "generic_seqatomic", - srcs = ["seqatomic_unsafe.go"], + srcs = ["generic_seqatomic_unsafe.go"], types = [ "Value", ], diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmaptest/BUILD new file mode 
100644 index 000000000..3f71ae97d --- /dev/null +++ b/pkg/sync/atomicptrmaptest/BUILD @@ -0,0 +1,57 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package( + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +go_template_instance( + name = "test_atomicptrmap", + out = "test_atomicptrmap_unsafe.go", + package = "atomicptrmap", + prefix = "test", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_template_instance( + name = "test_atomicptrmap_sharded", + out = "test_atomicptrmap_sharded_unsafe.go", + consts = { + "ShardOrder": "4", + }, + package = "atomicptrmap", + prefix = "test", + suffix = "Sharded", + template = "//pkg/sync:generic_atomicptrmap", + types = { + "Key": "int64", + "Value": "testValue", + }, +) + +go_library( + name = "atomicptrmap", + testonly = 1, + srcs = [ + "atomicptrmap.go", + "test_atomicptrmap_sharded_unsafe.go", + "test_atomicptrmap_unsafe.go", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + +go_test( + name = "atomicptrmap_test", + size = "small", + srcs = ["atomicptrmap_test.go"], + library = ":atomicptrmap", + deps = ["//pkg/sync"], +) diff --git a/tools/vm/test.cc b/pkg/sync/atomicptrmaptest/atomicptrmap.go index c0ceacda1..867821ce9 100644 --- a/tools/vm/test.cc +++ b/pkg/sync/atomicptrmaptest/atomicptrmap.go @@ -12,16 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gtest/gtest.h" +// Package atomicptrmap instantiates generic_atomicptrmap for testing. +package atomicptrmap -namespace { - -TEST(Image, Sanity0) { - // Do nothing (in shard 0). -} - -TEST(Image, Sanity1) { - // Do nothing (in shard 1). 
+type testValue struct { + val int } - -} // namespace diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go new file mode 100644 index 000000000..75a9997ef --- /dev/null +++ b/pkg/sync/atomicptrmaptest/atomicptrmap_test.go @@ -0,0 +1,635 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomicptrmap + +import ( + "context" + "fmt" + "math/rand" + "reflect" + "runtime" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sync" +) + +func TestConsistencyWithGoMap(t *testing.T) { + const maxKey = 16 + var vals [4]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + vals[i] = new(testValue) + } + var ( + m = make(map[int64]*testValue) + apm testAtomicPtrMap + ) + for i := 0; i < 100000; i++ { + // Apply a random operation to both m and apm and expect them to have + // the same result. Bias toward CompareAndSwap, which has the most + // cases; bias away from Range and RangeRepeatable, which are + // relatively expensive. 
+ switch rand.Intn(10) { + case 0, 1: // Load + key := rand.Int63n(maxKey) + want := m[key] + got := apm.Load(key) + t.Logf("Load(%d) = %p", key, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 2, 3: // Swap + key := rand.Int63n(maxKey) + val := vals[rand.Intn(len(vals))] + want := m[key] + if val != nil { + m[key] = val + } else { + delete(m, key) + } + got := apm.Swap(key, val) + t.Logf("Swap(%d, %p) = %p", key, val, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 4, 5, 6, 7: // CompareAndSwap + key := rand.Int63n(maxKey) + oldVal := vals[rand.Intn(len(vals))] + newVal := vals[rand.Intn(len(vals))] + want := m[key] + if want == oldVal { + if newVal != nil { + m[key] = newVal + } else { + delete(m, key) + } + } + got := apm.CompareAndSwap(key, oldVal, newVal) + t.Logf("CompareAndSwap(%d, %p, %p) = %p", key, oldVal, newVal, got) + if got != want { + t.Fatalf("got %p, wanted %p", got, want) + } + case 8: // Range + got := make(map[int64]*testValue) + var ( + haveDup = false + dup int64 + ) + apm.Range(func(key int64, val *testValue) bool { + if _, ok := got[key]; ok && !haveDup { + haveDup = true + dup = key + } + got[key] = val + return true + }) + t.Logf("Range() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + if haveDup { + t.Fatalf("got duplicate key %d", dup) + } + case 9: // RangeRepeatable + got := make(map[int64]*testValue) + apm.RangeRepeatable(func(key int64, val *testValue) bool { + got[key] = val + return true + }) + t.Logf("RangeRepeatable() = %v", got) + if !reflect.DeepEqual(got, m) { + t.Fatalf("got %v, wanted %v", got, m) + } + } + } +} + +func TestConcurrentHeterogeneous(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var ( + apm testAtomicPtrMap + wg sync.WaitGroup + ) + defer func() { + cancel() + wg.Wait() + }() + + possibleKeyValuePairs := make(map[int64]map[*testValue]struct{}) + addKeyValuePair := func(key int64, 
val *testValue) { + values := possibleKeyValuePairs[key] + if values == nil { + values = make(map[*testValue]struct{}) + possibleKeyValuePairs[key] = values + } + values[val] = struct{}{} + } + + const numValuesPerKey = 4 + + // These goroutines use keys not used by any other goroutine. + const numPrivateKeys = 3 + for i := 0; i < numPrivateKeys; i++ { + key := int64(i) + var vals [numValuesPerKey]*testValue + for i := 1; /* leave vals[0] nil */ i < len(vals); i++ { + val := new(testValue) + vals[i] = val + addKeyValuePair(key, val) + } + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + var stored *testValue + for ctx.Err() == nil { + switch r.Intn(4) { + case 0: + got := apm.Load(key) + if got != stored { + t.Errorf("Load(%d): got %p, wanted %p", key, got, stored) + return + } + case 1: + val := vals[r.Intn(len(vals))] + want := stored + stored = val + got := apm.Swap(key, val) + if got != want { + t.Errorf("Swap(%d, %p): got %p, wanted %p", key, val, got, want) + return + } + case 2, 3: + oldVal := vals[r.Intn(len(vals))] + newVal := vals[r.Intn(len(vals))] + want := stored + if stored == oldVal { + stored = newVal + } + got := apm.CompareAndSwap(key, oldVal, newVal) + if got != want { + t.Errorf("CompareAndSwap(%d, %p, %p): got %p, wanted %p", key, oldVal, newVal, got, want) + return + } + } + } + }() + } + + // These goroutines share a small set of keys. 
+ const numSharedKeys = 2 + var ( + sharedKeys [numSharedKeys]int64 + sharedValues = make(map[int64][]*testValue) + sharedValuesSet = make(map[int64]map[*testValue]struct{}) + ) + for i := range sharedKeys { + key := int64(numPrivateKeys + i) + sharedKeys[i] = key + vals := make([]*testValue, numValuesPerKey) + valsSet := make(map[*testValue]struct{}) + for j := range vals { + val := new(testValue) + vals[j] = val + valsSet[val] = struct{}{} + addKeyValuePair(key, val) + } + sharedValues[key] = vals + sharedValuesSet[key] = valsSet + } + randSharedValue := func(r *rand.Rand, key int64) *testValue { + vals := sharedValues[key] + return vals[r.Intn(len(vals))] + } + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + keyIndex := r.Intn(len(sharedKeys)) + key := sharedKeys[keyIndex] + var ( + op string + got *testValue + ) + switch r.Intn(4) { + case 0: + op = "Load" + got = apm.Load(key) + case 1: + op = "Swap" + got = apm.Swap(key, randSharedValue(r, key)) + case 2, 3: + op = "CompareAndSwap" + got = apm.CompareAndSwap(key, randSharedValue(r, key), randSharedValue(r, key)) + } + if got != nil { + valsSet := sharedValuesSet[key] + if _, ok := valsSet[got]; !ok { + t.Errorf("%s: got key %d, value %p; expected value in %v", op, key, got, valsSet) + return + } + } + } + }() + } + + // This goroutine repeatedly searches for unused keys. + wg.Add(1) + go func() { + defer wg.Done() + r := rand.New(rand.NewSource(rand.Int63())) + for ctx.Err() == nil { + key := -1 - r.Int63() + if got := apm.Load(key); got != nil { + t.Errorf("Load(%d): got %p, wanted nil", key, got) + } + } + }() + + // This goroutine repeatedly calls RangeRepeatable() and checks that each + // key corresponds to an expected value. 
+ wg.Add(1) + go func() { + defer wg.Done() + abort := false + for !abort && ctx.Err() == nil { + apm.RangeRepeatable(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("RangeRepeatable: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("RangeRepeatable: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + return true + }) + } + }() + + // Finally, the main goroutine spins for the length of the test calling + // Range() and checking that each key that it observes is unique and + // corresponds to an expected value. + seenKeys := make(map[int64]struct{}) + const testDuration = 5 * time.Second + end := time.Now().Add(testDuration) + abort := false + for time.Now().Before(end) { + apm.Range(func(key int64, val *testValue) bool { + values, ok := possibleKeyValuePairs[key] + if !ok { + t.Errorf("Range: got invalid key %d", key) + abort = true + return false + } + if _, ok := values[val]; !ok { + t.Errorf("Range: got key %d, value %p; expected one of %v", key, val, values) + abort = true + return false + } + if _, ok := seenKeys[key]; ok { + t.Errorf("Range: got duplicate key %d", key) + abort = true + return false + } + seenKeys[key] = struct{}{} + return true + }) + if abort { + break + } + for k := range seenKeys { + delete(seenKeys, k) + } + } +} + +type benchmarkableMap interface { + Load(key int64) *testValue + Store(key int64, val *testValue) + LoadOrStore(key int64, val *testValue) (*testValue, bool) + Delete(key int64) +} + +// rwMutexMap implements benchmarkableMap for a RWMutex-protected Go map. 
+type rwMutexMap struct { + mu sync.RWMutex + m map[int64]*testValue +} + +func (m *rwMutexMap) Load(key int64) *testValue { + m.mu.RLock() + defer m.mu.RUnlock() + return m.m[key] +} + +func (m *rwMutexMap) Store(key int64, val *testValue) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + m.m[key] = val +} + +func (m *rwMutexMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[int64]*testValue) + } + if oldVal, ok := m.m[key]; ok { + return oldVal, true + } + m.m[key] = val + return val, false +} + +func (m *rwMutexMap) Delete(key int64) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, key) +} + +// syncMap implements benchmarkableMap for a sync.Map. +type syncMap struct { + m sync.Map +} + +func (m *syncMap) Load(key int64) *testValue { + val, ok := m.m.Load(key) + if !ok { + return nil + } + return val.(*testValue) +} + +func (m *syncMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *syncMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + actual, loaded := m.m.LoadOrStore(key, val) + return actual.(*testValue), loaded +} + +func (m *syncMap) Delete(key int64) { + m.m.Delete(key) +} + +// benchmarkableAtomicPtrMap implements benchmarkableMap for testAtomicPtrMap. 
+type benchmarkableAtomicPtrMap struct { + m testAtomicPtrMap +} + +func (m *benchmarkableAtomicPtrMap) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMap) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMap) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMap) Delete(key int64) { + m.m.Store(key, nil) +} + +// benchmarkableAtomicPtrMapSharded implements benchmarkableMap for testAtomicPtrMapSharded. +type benchmarkableAtomicPtrMapSharded struct { + m testAtomicPtrMapSharded +} + +func (m *benchmarkableAtomicPtrMapSharded) Load(key int64) *testValue { + return m.m.Load(key) +} + +func (m *benchmarkableAtomicPtrMapSharded) Store(key int64, val *testValue) { + m.m.Store(key, val) +} + +func (m *benchmarkableAtomicPtrMapSharded) LoadOrStore(key int64, val *testValue) (*testValue, bool) { + if prev := m.m.CompareAndSwap(key, nil, val); prev != nil { + return prev, true + } + return val, false +} + +func (m *benchmarkableAtomicPtrMapSharded) Delete(key int64) { + m.m.Store(key, nil) +} + +var mapImpls = [...]struct { + name string + ctor func() benchmarkableMap +}{ + { + name: "RWMutexMap", + ctor: func() benchmarkableMap { + return new(rwMutexMap) + }, + }, + { + name: "SyncMap", + ctor: func() benchmarkableMap { + return new(syncMap) + }, + }, + { + name: "AtomicPtrMap", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMap) + }, + }, + { + name: "AtomicPtrMapSharded", + ctor: func() benchmarkableMap { + return new(benchmarkableAtomicPtrMapSharded) + }, + }, +} + +func benchmarkStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func 
BenchmarkStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLoadOrStoreDelete(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.LoadOrStore(int64(i), val) + } + for i := 0; i < b.N; i++ { + m.Delete(int64(i)) + } +} + +func BenchmarkLoadOrStoreDelete(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLoadOrStoreDelete(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupPositive(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(i)) + } +} + +func BenchmarkLookupPositive(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupPositive(b, mapImpl.ctor) + }) + } +} + +func benchmarkLookupNegative(b *testing.B, mapCtor func() benchmarkableMap) { + m := mapCtor() + val := &testValue{} + for i := 0; i < b.N; i++ { + m.Store(int64(i), val) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Load(int64(-1 - i)) + } +} + +func BenchmarkLookupNegative(b *testing.B) { + for _, mapImpl := range mapImpls { + b.Run(mapImpl.name, func(b *testing.B) { + benchmarkLookupNegative(b, mapImpl.ctor) + }) + } +} + +type benchmarkConcurrentOptions struct { + // loadsPerMutationPair is the number of map lookups between each + // insertion/deletion pair. + loadsPerMutationPair int + + // If changeKeys is true, the keys used by each goroutine change between + // iterations of the test. 
+ changeKeys bool +} + +func benchmarkConcurrent(b *testing.B, mapCtor func() benchmarkableMap, opts benchmarkConcurrentOptions) { + var ( + started sync.WaitGroup + workers sync.WaitGroup + ) + started.Add(1) + + m := mapCtor() + val := &testValue{} + // Insert a large number of unused elements into the map so that used + // elements are distributed throughout memory. + for i := 0; i < 10000; i++ { + m.Store(int64(-1-i), val) + } + // n := ceil(b.N / (opts.loadsPerMutationPair + 2)) + n := (b.N + opts.loadsPerMutationPair + 1) / (opts.loadsPerMutationPair + 2) + for i, procs := 0, runtime.GOMAXPROCS(0); i < procs; i++ { + workerID := i + workers.Add(1) + go func() { + defer workers.Done() + started.Wait() + for i := 0; i < n; i++ { + var key int64 + if opts.changeKeys { + key = int64(workerID*n + i) + } else { + key = int64(workerID) + } + m.LoadOrStore(key, val) + for j := 0; j < opts.loadsPerMutationPair; j++ { + m.Load(key) + } + m.Delete(key) + } + }() + } + + b.ResetTimer() + started.Done() + workers.Wait() +} + +func BenchmarkConcurrent(b *testing.B) { + changeKeysChoices := [...]struct { + name string + val bool + }{ + {"FixedKeys", false}, + {"ChangingKeys", true}, + } + writePcts := [...]struct { + name string + loadsPerMutationPair int + }{ + {"1PercentWrites", 198}, + {"10PercentWrites", 18}, + {"50PercentWrites", 2}, + } + for _, changeKeys := range changeKeysChoices { + for _, writePct := range writePcts { + for _, mapImpl := range mapImpls { + name := fmt.Sprintf("%s_%s_%s", changeKeys.name, writePct.name, mapImpl.name) + b.Run(name, func(b *testing.B) { + benchmarkConcurrent(b, mapImpl.ctor, benchmarkConcurrentOptions{ + loadsPerMutationPair: writePct.loadsPerMutationPair, + changeKeys: changeKeys.val, + }) + }) + } + } + } +} diff --git a/pkg/sync/atomicptr_unsafe.go b/pkg/sync/generic_atomicptr_unsafe.go index 525c4beed..82b6df18c 100644 --- a/pkg/sync/atomicptr_unsafe.go +++ b/pkg/sync/generic_atomicptr_unsafe.go @@ -3,9 +3,9 @@ // Use of this 
source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( "sync/atomic" diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/generic_atomicptrmap_unsafe.go new file mode 100644 index 000000000..c70dda6dd --- /dev/null +++ b/pkg/sync/generic_atomicptrmap_unsafe.go @@ -0,0 +1,503 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package atomicptrmap doesn't exist. This file must be instantiated using the +// go_template_instance rule in tools/go_generics/defs.bzl. +package atomicptrmap + +import ( + "reflect" + "runtime" + "sync/atomic" + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// Key is a required type parameter. +type Key struct{} + +// Value is a required type parameter. +type Value struct{} + +const ( + // ShardOrder is an optional parameter specifying the base-2 log of the + // number of shards per AtomicPtrMap. Higher values of ShardOrder reduce + // unnecessary synchronization between unrelated concurrent operations, + // improving performance for write-heavy workloads, but increase memory + // usage for small maps. 
+ ShardOrder = 0 +) + +// Hasher is an optional type parameter. If Hasher is provided, it must define +// the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps. +type Hasher struct { + defaultHasher +} + +// defaultHasher is the default Hasher. This indirection exists because +// defaultHasher must exist even if a custom Hasher is provided, to prevent the +// Go compiler from complaining about defaultHasher's unused imports. +type defaultHasher struct { + fn func(unsafe.Pointer, uintptr) uintptr + seed uintptr +} + +// Init initializes the Hasher. +func (h *defaultHasher) Init() { + h.fn = sync.MapKeyHasher(map[Key]*Value(nil)) + h.seed = sync.RandUintptr() +} + +// Hash returns the hash value for the given Key. +func (h *defaultHasher) Hash(key Key) uintptr { + return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed) +} + +var hasher Hasher + +func init() { + hasher.Init() +} + +// An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMap are +// safe for concurrent use from multiple goroutines without additional +// synchronization. +// +// The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for +// use. AtomicPtrMaps must not be copied after first use. +// +// sync.Map may be faster than AtomicPtrMap if most operations on the map are +// concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in +// other circumstances. +type AtomicPtrMap struct { + // AtomicPtrMap is implemented as a hash table with the following + // properties: + // + // * Collisions are resolved with quadratic probing. Of the two major + // alternatives, Robin Hood linear probing makes it difficult for writers + // to execute in parallel, and bucketing is less effective in Go due to + // lack of SIMD. + // + // * The table is optionally divided into shards indexed by hash to further + // reduce unnecessary synchronization. 
+ + shards [1 << ShardOrder]apmShard +} + +func (m *AtomicPtrMap) shard(hash uintptr) *apmShard { + // Go defines right shifts >= width of shifted unsigned operand as 0, so + // this is correct even if ShardOrder is 0 (although nogo complains because + // nogo is dumb). + const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder + index := hash >> indexLSB + return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{})))) +} + +type apmShard struct { + apmShardMutationData + _ [apmShardMutationDataPadding]byte + apmShardLookupData + _ [apmShardLookupDataPadding]byte +} + +type apmShardMutationData struct { + dirtyMu sync.Mutex // serializes slot transitions out of empty + dirty uintptr // # slots with val != nil + count uintptr // # slots with val != nil and val != tombstone() + rehashMu sync.Mutex // serializes rehashing +} + +type apmShardLookupData struct { + seq sync.SeqCount // allows atomic reads of slots+mask + slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq + mask uintptr // always (a power of 2) - 1; protected by rehashMu/seq +} + +const ( + cacheLineBytes = 64 + // Cache line padding is enabled if sharding is. + apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise + // The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) % + // cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes). + apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1) + apmShardMutationDataPadding = apmEnablePadding * apmShardMutationDataRequiredPadding + apmShardLookupDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1) + apmShardLookupDataPadding = apmEnablePadding * apmShardLookupDataRequiredPadding + + // These define fractional thresholds for when apmShard.rehash() is called + // (i.e. 
the load factor) and when it rehases to a larger table + // respectively. They are chosen such that the rehash threshold = the + // expansion threshold + 1/2, so that when reuse of deleted slots is rare + // or non-existent, rehashing occurs after the insertion of at least 1/2 + // the table's size in new entries, which is acceptably infrequent. + apmRehashThresholdNum = 2 + apmRehashThresholdDen = 3 + apmExpansionThresholdNum = 1 + apmExpansionThresholdDen = 6 +) + +type apmSlot struct { + // slot states are indicated by val: + // + // * Empty: val == nil; key is meaningless. May transition to full or + // evacuated with dirtyMu locked. + // + // * Full: val != nil, tombstone(), or evacuated(); key is immutable. val + // is the Value mapped to key. May transition to deleted or evacuated. + // + // * Deleted: val == tombstone(); key is still immutable. key is mapped to + // no Value. May transition to full or evacuated. + // + // * Evacuated: val == evacuated(); key is immutable. Set by rehashing on + // slots that have already been moved, requiring readers to wait for + // rehashing to complete and use the new table. Terminal state. + // + // Note that once val is non-nil, it cannot become nil again. That is, the + // transition from empty to non-empty is irreversible for a given slot; + // the only way to create more empty slots is by rehashing. + val unsafe.Pointer + key Key +} + +func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot { + return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{}))) +} + +var tombstoneObj byte + +func tombstone() unsafe.Pointer { + return unsafe.Pointer(&tombstoneObj) +} + +var evacuatedObj byte + +func evacuated() unsafe.Pointer { + return unsafe.Pointer(&evacuatedObj) +} + +// Load returns the Value stored in m for key. 
+func (m *AtomicPtrMap) Load(key Key) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + return nil + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Empty slot; end of probe sequence. + return nil + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + i = (i + inc) & mask + inc++ + } +} + +// Store stores the Value val for key. +func (m *AtomicPtrMap) Store(key Key, val *Value) { + m.maybeCompareAndSwap(key, false, nil, val) +} + +// Swap stores the Value val for key and returns the previously-mapped Value. +func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value { + return m.maybeCompareAndSwap(key, false, nil, val) +} + +// CompareAndSwap checks that the Value stored for key is oldVal; if it is, it +// stores the Value newVal for key. CompareAndSwap returns the previous Value +// stored for key, whether or not it stores newVal. 
+func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value { + return m.maybeCompareAndSwap(key, true, oldVal, newVal) +} + +func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value { + hash := hasher.Hash(key) + shard := m.shard(hash) + oldVal := tombstone() + if typedOldVal != nil { + oldVal = unsafe.Pointer(typedOldVal) + } + newVal := tombstone() + if typedNewVal != nil { + newVal = unsafe.Pointer(typedNewVal) + } + +retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Need to allocate a table before insertion. + shard.rehash(nil) + goto retry + } + + i := hash & mask + inc := uintptr(1) + for { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil { + if (compare && oldVal != tombstone()) || newVal == tombstone() { + return nil + } + // Try to grab this slot for ourselves. + shard.dirtyMu.Lock() + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == nil { + // Check if we need to rehash before dirtying a slot. + if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum { + shard.dirtyMu.Unlock() + shard.rehash(slots) + goto retry + } + slot.key = key + atomic.StorePointer(&slot.val, newVal) // transitions slot to full + shard.dirty++ + atomic.AddUintptr(&shard.count, 1) + shard.dirtyMu.Unlock() + return nil + } + // Raced with another store; the slot is no longer empty. Continue + // with the new value of slotVal since we may have raced with + // another store of key. + shard.dirtyMu.Unlock() + } + if slotVal == evacuated() { + // Racing with rehashing. + goto retry + } + if slot.key == key { + // We're reusing an existing slot, so rehashing isn't necessary. 
+ for { + if (compare && oldVal != slotVal) || newVal == slotVal { + if slotVal == tombstone() { + return nil + } + return (*Value)(slotVal) + } + if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) { + if slotVal == tombstone() { + atomic.AddUintptr(&shard.count, 1) + return nil + } + if newVal == tombstone() { + atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */) + } + return (*Value)(slotVal) + } + slotVal = atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + } + } + // This produces a triangular number sequence of offsets from the + // initially-probed position. + i = (i + inc) & mask + inc++ + } +} + +// rehash is marked nosplit to avoid preemption during table copying. +//go:nosplit +func (shard *apmShard) rehash(oldSlots unsafe.Pointer) { + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + + if shard.slots != oldSlots { + // Raced with another call to rehash(). + return + } + + // Determine the size of the new table. Constraints: + // + // * The size of the table must be a power of two to ensure that every slot + // is visitable by every probe sequence under quadratic probing with + // triangular numbers. + // + // * The size of the table cannot decrease because even if shard.count is + // currently smaller than shard.dirty, concurrent stores that reuse + // existing slots can drive shard.count back up to a maximum of + // shard.dirty. + newSize := uintptr(8) // arbitrary initial size + if oldSlots != nil { + oldSize := shard.mask + 1 + newSize = oldSize + if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum { + newSize *= 2 + } + } + + // Allocate the new table. 
+ newSlotsSlice := make([]apmSlot, newSize) + newSlotsReflect := (*reflect.SliceHeader)(unsafe.Pointer(&newSlotsSlice)) + newSlots := unsafe.Pointer(newSlotsReflect.Data) + runtime.KeepAlive(newSlotsSlice) + newMask := newSize - 1 + + // Start a writer critical section now so that racing users of the old + // table that observe evacuated() wait for the new table. (But lock dirtyMu + // first since doing so may block, which we don't want to do during the + // writer critical section.) + shard.dirtyMu.Lock() + shard.seq.BeginWrite() + + if oldSlots != nil { + realCount := uintptr(0) + // Copy old entries to the new table. + oldMask := shard.mask + for i := uintptr(0); i <= oldMask; i++ { + oldSlot := apmSlotAt(oldSlots, i) + val := atomic.SwapPointer(&oldSlot.val, evacuated()) + if val == nil || val == tombstone() { + continue + } + hash := hasher.Hash(oldSlot.key) + j := hash & newMask + inc := uintptr(1) + for { + newSlot := apmSlotAt(newSlots, j) + if newSlot.val == nil { + newSlot.val = val + newSlot.key = oldSlot.key + break + } + j = (j + inc) & newMask + inc++ + } + realCount++ + } + // Update dirty to reflect that tombstones were not copied to the new + // table. Use realCount since a concurrent mutator may not have updated + // shard.count yet. + shard.dirty = realCount + } + + // Switch to the new table. + atomic.StorePointer(&shard.slots, newSlots) + atomic.StoreUintptr(&shard.mask, newMask) + + shard.seq.EndWrite() + shard.dirtyMu.Unlock() +} + +// Range invokes f on each Key-Value pair stored in m. If any call to f returns +// false, Range stops iteration and returns. +// +// Range does not necessarily correspond to any consistent snapshot of the +// Map's contents: no Key will be visited more than once, but if the Value for +// any Key is stored or deleted concurrently, Range may reflect any mapping for +// that Key from any point during the Range call. +// +// f must not call other methods on m. 
+func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + if !shard.doRange(f) { + return + } + } +} + +func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool { + // We have to lock rehashMu because if we handled races with rehashing by + // retrying, f could see the same key twice. + shard.rehashMu.Lock() + defer shard.rehashMu.Unlock() + slots := shard.slots + if slots == nil { + return true + } + mask := shard.mask + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return false + } + } + return true +} + +// RangeRepeatable is like Range, but: +// +// * RangeRepeatable may visit the same Key multiple times in the presence of +// concurrent mutators, possibly passing different Values to f in different +// calls. +// +// * It is safe for f to call other methods on m. +func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) { + for si := 0; si < len(m.shards); si++ { + shard := &m.shards[si] + + retry: + epoch := shard.seq.BeginRead() + slots := atomic.LoadPointer(&shard.slots) + mask := atomic.LoadUintptr(&shard.mask) + if !shard.seq.ReadOk(epoch) { + goto retry + } + if slots == nil { + continue + } + + for i := uintptr(0); i <= mask; i++ { + slot := apmSlotAt(slots, i) + slotVal := atomic.LoadPointer(&slot.val) + if slotVal == evacuated() { + goto retry + } + if slotVal == nil || slotVal == tombstone() { + continue + } + if !f(slot.key, (*Value)(slotVal)) { + return + } + } + } +} diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/generic_seqatomic_unsafe.go index 780f3b8f8..82b676abf 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/generic_seqatomic_unsafe.go @@ -3,9 +3,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
-// Package template doesn't exist. This file must be instantiated using the +// Package seqatomic doesn't exist. This file must be instantiated using the // go_template_instance rule in tools/go_generics/defs.bzl. -package template +package seqatomic import ( "unsafe" diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 7ad6a4434..e925e2e5b 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -11,6 +11,8 @@ package sync import ( + "fmt" + "reflect" "unsafe" ) @@ -61,6 +63,57 @@ const ( TraceEvGoBlockSelect byte = 24 ) +// Rand32 returns a non-cryptographically-secure random uint32. +func Rand32() uint32 { + return fastrand() +} + +// Rand64 returns a non-cryptographically-secure random uint64. +func Rand64() uint64 { + return uint64(fastrand())<<32 | uint64(fastrand()) +} + +//go:linkname fastrand runtime.fastrand +func fastrand() uint32 + +// RandUintptr returns a non-cryptographically-secure random uintptr. +func RandUintptr() uintptr { + if unsafe.Sizeof(uintptr(0)) == 4 { + return uintptr(Rand32()) + } + return uintptr(Rand64()) +} + +// MapKeyHasher returns a hash function for pointers of m's key type. +// +// Preconditions: m must be a map. +func MapKeyHasher(m interface{}) func(unsafe.Pointer, uintptr) uintptr { + if rtyp := reflect.TypeOf(m); rtyp.Kind() != reflect.Map { + panic(fmt.Sprintf("sync.MapKeyHasher: m is %v, not map", rtyp)) + } + mtyp := *(**maptype)(unsafe.Pointer(&m)) + return mtyp.hasher +} + +type maptype struct { + size uintptr + ptrdata uintptr + hash uint32 + tflag uint8 + align uint8 + fieldAlign uint8 + kind uint8 + equal func(unsafe.Pointer, unsafe.Pointer) bool + gcdata *byte + str int32 + ptrToThis int32 + key unsafe.Pointer + elem unsafe.Pointer + bucket unsafe.Pointer + hasher func(unsafe.Pointer, uintptr) uintptr + // more fields +} + // These functions are only used within the sync package. 
//go:linkname semacquire sync.runtime_Semacquire diff --git a/pkg/sync/rwmutex_test.go b/pkg/sync/rwmutex_test.go index ce667e825..5ca96d12b 100644 --- a/pkg/sync/rwmutex_test.go +++ b/pkg/sync/rwmutex_test.go @@ -102,7 +102,7 @@ func downgradingWriter(rwm *RWMutex, numIterations int, activity *int32, cdone c } for i := 0; i < 100; i++ { } - n = atomic.AddInt32(activity, -1) + atomic.AddInt32(activity, -1) rwm.RUnlock() } cdone <- true diff --git a/pkg/syserr/host_linux.go b/pkg/syserr/host_linux.go index fc6ef60a1..77faa3670 100644 --- a/pkg/syserr/host_linux.go +++ b/pkg/syserr/host_linux.go @@ -32,7 +32,7 @@ var linuxHostTranslations [maxErrno]linuxHostTranslation // FromHost translates a syscall.Errno to a corresponding Error value. func FromHost(err syscall.Errno) *Error { - if err < 0 || int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { + if int(err) >= len(linuxHostTranslations) || !linuxHostTranslations[err].ok { panic(fmt.Sprintf("unknown host errno %q (%d)", err.Error(), err)) } return linuxHostTranslations[err].err diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 5ae10939d..77c3c110c 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -15,6 +15,8 @@ package syserr import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,45 +50,56 @@ var ( ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) ) -var netstackErrorTranslations = map[*tcpip.Error]*Error{ - tcpip.ErrUnknownProtocol: ErrUnknownProtocol, - tcpip.ErrUnknownNICID: ErrUnknownNICID, - tcpip.ErrUnknownDevice: ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID: ErrDuplicateNICID, - tcpip.ErrDuplicateAddress: ErrDuplicateAddress, - tcpip.ErrNoRoute: ErrNoRoute, - tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound: ErrAlreadyBound, - tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting: 
ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected: ErrAlreadyConnected, - tcpip.ErrNoPortAvailable: ErrNoPortAvailable, - tcpip.ErrPortInUse: ErrPortInUse, - tcpip.ErrBadLocalAddress: ErrBadLocalAddress, - tcpip.ErrClosedForSend: ErrClosedForSend, - tcpip.ErrClosedForReceive: ErrClosedForReceive, - tcpip.ErrWouldBlock: ErrWouldBlock, - tcpip.ErrConnectionRefused: ErrConnectionRefused, - tcpip.ErrTimeout: ErrTimeout, - tcpip.ErrAborted: ErrAborted, - tcpip.ErrConnectStarted: ErrConnectStarted, - tcpip.ErrDestinationRequired: ErrDestinationRequired, - tcpip.ErrNotSupported: ErrNotSupported, - tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, - tcpip.ErrNotConnected: ErrNotConnected, - tcpip.ErrConnectionReset: ErrConnectionReset, - tcpip.ErrConnectionAborted: ErrConnectionAborted, - tcpip.ErrNoSuchFile: ErrNoSuchFile, - tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress: ErrHostDown, - tcpip.ErrBadAddress: ErrBadAddress, - tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, - tcpip.ErrMessageTooLong: ErrMessageTooLong, - tcpip.ErrNoBufferSpace: ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, - tcpip.ErrNotPermitted: ErrNotPermittedNet, - tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported, +var netstackErrorTranslations map[string]*Error + +func addErrMapping(tcpipErr *tcpip.Error, netstackErr *Error) { + key := tcpipErr.String() + if _, ok := netstackErrorTranslations[key]; ok { + panic(fmt.Sprintf("duplicate error key: %s", key)) + } + netstackErrorTranslations[key] = netstackErr +} + +func init() { + netstackErrorTranslations = make(map[string]*Error) + addErrMapping(tcpip.ErrUnknownProtocol, ErrUnknownProtocol) + addErrMapping(tcpip.ErrUnknownNICID, ErrUnknownNICID) + addErrMapping(tcpip.ErrUnknownDevice, ErrUnknownDevice) + addErrMapping(tcpip.ErrUnknownProtocolOption, ErrUnknownProtocolOption) + addErrMapping(tcpip.ErrDuplicateNICID, ErrDuplicateNICID) + 
addErrMapping(tcpip.ErrDuplicateAddress, ErrDuplicateAddress) + addErrMapping(tcpip.ErrNoRoute, ErrNoRoute) + addErrMapping(tcpip.ErrBadLinkEndpoint, ErrBadLinkEndpoint) + addErrMapping(tcpip.ErrAlreadyBound, ErrAlreadyBound) + addErrMapping(tcpip.ErrInvalidEndpointState, ErrInvalidEndpointState) + addErrMapping(tcpip.ErrAlreadyConnecting, ErrAlreadyConnecting) + addErrMapping(tcpip.ErrAlreadyConnected, ErrAlreadyConnected) + addErrMapping(tcpip.ErrNoPortAvailable, ErrNoPortAvailable) + addErrMapping(tcpip.ErrPortInUse, ErrPortInUse) + addErrMapping(tcpip.ErrBadLocalAddress, ErrBadLocalAddress) + addErrMapping(tcpip.ErrClosedForSend, ErrClosedForSend) + addErrMapping(tcpip.ErrClosedForReceive, ErrClosedForReceive) + addErrMapping(tcpip.ErrWouldBlock, ErrWouldBlock) + addErrMapping(tcpip.ErrConnectionRefused, ErrConnectionRefused) + addErrMapping(tcpip.ErrTimeout, ErrTimeout) + addErrMapping(tcpip.ErrAborted, ErrAborted) + addErrMapping(tcpip.ErrConnectStarted, ErrConnectStarted) + addErrMapping(tcpip.ErrDestinationRequired, ErrDestinationRequired) + addErrMapping(tcpip.ErrNotSupported, ErrNotSupported) + addErrMapping(tcpip.ErrQueueSizeNotSupported, ErrQueueSizeNotSupported) + addErrMapping(tcpip.ErrNotConnected, ErrNotConnected) + addErrMapping(tcpip.ErrConnectionReset, ErrConnectionReset) + addErrMapping(tcpip.ErrConnectionAborted, ErrConnectionAborted) + addErrMapping(tcpip.ErrNoSuchFile, ErrNoSuchFile) + addErrMapping(tcpip.ErrInvalidOptionValue, ErrInvalidOptionValue) + addErrMapping(tcpip.ErrNoLinkAddress, ErrHostDown) + addErrMapping(tcpip.ErrBadAddress, ErrBadAddress) + addErrMapping(tcpip.ErrNetworkUnreachable, ErrNetworkUnreachable) + addErrMapping(tcpip.ErrMessageTooLong, ErrMessageTooLong) + addErrMapping(tcpip.ErrNoBufferSpace, ErrNoBufferSpace) + addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) + addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) + addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) } 
// TranslateNetstackError converts an error from the tcpip package to a sentry @@ -95,7 +108,7 @@ func TranslateNetstackError(err *tcpip.Error) *Error { if err == nil { return nil } - se, ok := netstackErrorTranslations[err] + se, ok := netstackErrorTranslations[err.String()] if !ok { panic("Unknown error: " + err.String()) } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 27f96a3ac..89b765f1b 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -1,10 +1,24 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "sock_err_list", + out = "sock_err_list.go", + package = "tcpip", + prefix = "sockError", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SockError", + "Linker": "*SockError", + }, +) + go_library( name = "tcpip", srcs = [ + "sock_err_list.go", "socketops.go", "tcpip.go", "time_unsafe.go", diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d3ae56ac6..91971b687 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -117,6 +117,10 @@ func TTL(ttl uint8) NetworkChecker { v = ip.TTL() case header.IPv6: v = ip.HopLimit() + case *ipv6HeaderWithExtHdr: + v = ip.HopLimit() + default: + t.Fatalf("unrecognized header type %T for TTL evaluation", ip) } if v != ttl { t.Fatalf("Bad TTL, got = %d, want = %d", v, ttl) @@ -321,6 +325,19 @@ func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { } } +// ReceiveOriginalDstAddr creates a checker that checks the OriginalDstAddress +// field in ControlMessages. 
+func ReceiveOriginalDstAddr(want tcpip.FullAddress) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasOriginalDstAddress { + t.Errorf("got cm.HasOriginalDstAddress = %t, want = true", cm.HasOriginalDstAddress) + } else if diff := cmp.Diff(want, cm.OriginalDstAddress); diff != "" { + t.Errorf("OriginalDstAddress mismatch (-want +got):\n%s", diff) + } + } +} + // TOS creates a checker that checks the TOS field. func TOS(tos uint8, label uint32) NetworkChecker { return func(t *testing.T, h []header.Network) { @@ -1400,3 +1417,189 @@ func IGMPGroupAddress(want tcpip.Address) TransportChecker { } } } + +// IPv6ExtHdrChecker is a function to check an extension header. +type IPv6ExtHdrChecker func(*testing.T, header.IPv6PayloadHeader) + +// IPv6WithExtHdr is like IPv6 but allows IPv6 packets with extension headers. +func IPv6WithExtHdr(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("not a valid IPv6 packet") + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var rawPayloadHeader header.IPv6RawPayloadHeader + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, true, _)", h, done) + return + } + r, ok := h.(header.IPv6RawPayloadHeader) + if ok { + rawPayloadHeader = r + break + } + } + + networkHeader := ipv6HeaderWithExtHdr{ + IPv6: ipv6, + transport: tcpip.TransportProtocolNumber(rawPayloadHeader.Identifier), + payload: rawPayloadHeader.Buf.ToView(), + } + + for _, checker := range checkers { + checker(t, []header.Network{&networkHeader}) + } +} + +// IPv6ExtHdr checks for the presence of extension headers. 
+// +// All the extension headers in headers will be checked exhaustively in the +// order provided. +func IPv6ExtHdr(headers ...IPv6ExtHdrChecker) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + extHdrs, ok := h[0].(*ipv6HeaderWithExtHdr) + if !ok { + t.Errorf("got network header = %T, want = *ipv6HeaderWithExtHdr", h[0]) + return + } + + payloadIterator := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(extHdrs.IPv6.NextHeader()), + buffer.View(extHdrs.IPv6.Payload()).ToVectorisedView(), + ) + + for _, check := range headers { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, false, _)", h, done) + return + } + check(t, h) + } + // Validate we consumed all headers. + // + // The next one over should be a raw payload and then iterator should + // terminate. + wantDone := false + for { + h, done, err := payloadIterator.Next() + if err != nil { + t.Errorf("payloadIterator.Next(): %s", err) + return + } + if done != wantDone { + t.Errorf("got payloadIterator.Next() = (%T, %t, _), want = (_, %t, _)", h, done, wantDone) + return + } + if done { + break + } + if _, ok := h.(header.IPv6RawPayloadHeader); !ok { + t.Errorf("got payloadIterator.Next() = (%T, _, _), want = (header.IPv6RawPayloadHeader, _, _)", h) + continue + } + wantDone = true + } + } +} + +var _ header.Network = (*ipv6HeaderWithExtHdr)(nil) + +// ipv6HeaderWithExtHdr provides a header.Network implementation that takes +// extension headers into consideration, which is not the case with vanilla +// header.IPv6. +type ipv6HeaderWithExtHdr struct { + header.IPv6 + transport tcpip.TransportProtocolNumber + payload []byte +} + +// TransportProtocol implements header.Network. 
+func (h *ipv6HeaderWithExtHdr) TransportProtocol() tcpip.TransportProtocolNumber { + return h.transport +} + +// Payload implements header.Network. +func (h *ipv6HeaderWithExtHdr) Payload() []byte { + return h.payload +} + +// IPv6ExtHdrOptionChecker is a function to check an extension header option. +type IPv6ExtHdrOptionChecker func(*testing.T, header.IPv6ExtHdrOption) + +// IPv6HopByHopExtensionHeader checks the extension header is a Hop by Hop +// extension header and validates the containing options with checkers. +// +// checkers must exhaustively contain all the expected options. +func IPv6HopByHopExtensionHeader(checkers ...IPv6ExtHdrOptionChecker) IPv6ExtHdrChecker { + return func(t *testing.T, payloadHeader header.IPv6PayloadHeader) { + t.Helper() + + hbh, ok := payloadHeader.(header.IPv6HopByHopOptionsExtHdr) + if !ok { + t.Errorf("unexpected IPv6 payload header, got = %T, want = header.IPv6HopByHopOptionsExtHdr", payloadHeader) + return + } + optionsIterator := hbh.Iter() + for _, f := range checkers { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + f(t, opt) + } + // Validate all options were consumed. + for { + opt, done, err := optionsIterator.Next() + if err != nil { + t.Errorf("optionsIterator.Next(): %s", err) + return + } + if !done { + t.Errorf("got optionsIterator.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + if done { + break + } + } + } +} + +// IPv6RouterAlert validates that an extension header option is the RouterAlert +// option and matches on its value. 
+func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { + return func(t *testing.T, opt header.IPv6ExtHdrOption) { + routerAlert, ok := opt.(*header.IPv6RouterAlertOption) + if !ok { + t.Errorf("unexpected extension header option, got = %T, want = header.IPv6RouterAlertOption", opt) + return + } + if routerAlert.Value != want { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, want) + } + } +} diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 309403482..5ab20ee86 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -19,6 +19,7 @@ package header_test import ( "fmt" "math/rand" + "sync" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -169,3 +170,96 @@ func BenchmarkChecksum(b *testing.B) { } } } + +func testICMPChecksum(t *testing.T, headerChecksum func() uint16, icmpChecksum func() uint16, want uint16, pktStr string) { + // icmpChecksum should not do any modifications of the header to + // calculate its checksum. Let's call it from a few go-routines and the + // race detector will trigger a warning if there are any concurrent + // read/write accesses. 
+ + const concurrency = 5 + start := make(chan int) + ready := make(chan bool, concurrency) + var wg sync.WaitGroup + wg.Add(concurrency) + defer wg.Wait() + + for i := 0; i < concurrency; i++ { + go func() { + defer wg.Done() + + ready <- true + <-start + + if got := headerChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + if got := icmpChecksum(); want != got { + t.Errorf("new checksum for %s does not match old got: %x, want: %x", pktStr, got, want) + } + }() + } + for i := 0; i < concurrency; i++ { + <-ready + } + close(start) +} + +func TestICMPv4Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:5]), + buffer.NewViewFromBytes(buf[5:]), + }) + + want := header.Checksum(vv.ToView(), 0) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv4Checksum(h, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} + +func TestICMPv6Checksum(t *testing.T) { + rnd := rand.New(rand.NewSource(42)) + + h := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize)) + if _, err := rnd.Read(h); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + h.SetChecksum(0) + + buf := make([]byte, 13) + if _, err := rnd.Read(buf); err != nil { + t.Fatalf("rnd.Read failed: %v", err) + } + vv := buffer.NewVectorisedView(len(buf), []buffer.View{ + buffer.NewViewFromBytes(buf[:7]), + buffer.NewViewFromBytes(buf[7:10]), + buffer.NewViewFromBytes(buf[10:]), + }) + + dst := header.IPv6Loopback + src := header.IPv6Loopback + + want := 
header.PseudoHeaderChecksum(header.ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + want = header.Checksum(vv.ToView(), want) + want = ^header.Checksum(h, want) + h.SetChecksum(want) + + testICMPChecksum(t, h.Checksum, func() uint16 { + return header.ICMPv6Checksum(h, src, dst, vv) + }, want, fmt.Sprintf("header: {% x} data {% x}", h, vv.ToView())) +} diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 2f13dea6a..5f9b8e9e2 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "fmt" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -199,17 +200,24 @@ func (b ICMPv4) SetSequence(sequence uint16) { // ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header, // and payload. func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := uint16(0) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } + xsum := ChecksumVV(vv, 0) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. + xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 + return ^xsum +} - return xsum +// ICMPOriginFromNetProto returns the appropriate SockErrOrigin to use when +// a packet having a `net` header causing an ICMP error. 
+func ICMPOriginFromNetProto(net tcpip.NetworkProtocolNumber) tcpip.SockErrOrigin { + switch net { + case IPv4ProtocolNumber: + return tcpip.SockExtErrorOriginICMP + case IPv6ProtocolNumber: + return tcpip.SockExtErrorOriginICMP6 + default: + panic(fmt.Sprintf("unsupported net proto to extract ICMP error origin: %d", net)) + } } diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index 2eef64b4d..eca9750ab 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -265,22 +265,13 @@ func (b ICMPv6) Payload() []byte { // ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header, // IPv6 src/dst addresses and the payload. func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 { - // Calculate the IPv6 pseudo-header upper-layer checksum. - xsum := Checksum([]byte(src), 0) - xsum = Checksum([]byte(dst), xsum) - var upperLayerLength [4]byte - binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size())) - xsum = Checksum(upperLayerLength[:], xsum) - xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum) - for _, v := range vv.Views() { - xsum = Checksum(v, xsum) - } - - // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum. - h2, h3 := h[2], h[3] - h[2], h[3] = 0, 0 - xsum = ^Checksum(h, xsum) - h[2], h[3] = h2, h3 - - return xsum + xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, src, dst, uint16(len(h)+vv.Size())) + + xsum = ChecksumVV(vv, xsum) + + // h[2:4] is the checksum itself, skip it to avoid checksumming the checksum. 
+ xsum = Checksum(h[:2], xsum) + xsum = Checksum(h[4:], xsum) + + return ^xsum } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 55d09355a..5580d6a78 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -18,7 +18,6 @@ import ( "crypto/sha256" "encoding/binary" "fmt" - "strings" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -48,11 +47,13 @@ type IPv6Fields struct { // FlowLabel is the "flow label" field of an IPv6 packet. FlowLabel uint32 - // PayloadLength is the "payload length" field of an IPv6 packet. + // PayloadLength is the "payload length" field of an IPv6 packet, including + // the length of all extension headers. PayloadLength uint16 - // NextHeader is the "next header" field of an IPv6 packet. - NextHeader uint8 + // TransportProtocol is the transport layer protocol number. Serialized in the + // last "next header" field of the IPv6 header + extension headers. + TransportProtocol tcpip.TransportProtocolNumber // HopLimit is the "Hop Limit" field of an IPv6 packet. HopLimit uint8 @@ -62,6 +63,9 @@ type IPv6Fields struct { // DstAddr is the "destination ip address" of an IPv6 packet. DstAddr tcpip.Address + + // ExtensionHeaders are the extension headers following the IPv6 header. + ExtensionHeaders IPv6ExtHdrSerializer } // IPv6 represents an ipv6 header stored in a byte array. @@ -148,13 +152,17 @@ const ( // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the // catch-all or wildcard subnet. That is, all IPv6 addresses are considered to // be contained within this subnet. -var IPv6EmptySubnet = func() tcpip.Subnet { - subnet, err := tcpip.NewSubnet(IPv6Any, tcpip.AddressMask(IPv6Any)) - if err != nil { - panic(err) - } - return subnet -}() +var IPv6EmptySubnet = tcpip.AddressWithPrefix{ + Address: IPv6Any, + PrefixLen: 0, +}.Subnet() + +// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined +// by RFC 4291 section 2.5.5. 
+var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{ + Address: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00", + PrefixLen: 96, +}.Subnet() // IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined // by RFC 4291 section 2.5.6. @@ -253,12 +261,14 @@ func (IPv6) SetChecksum(uint16) { // Encode encodes all the fields of the ipv6 header. func (b IPv6) Encode(i *IPv6Fields) { + extHdr := b[IPv6MinimumSize:] b.SetTOS(i.TrafficClass, i.FlowLabel) b.SetPayloadLength(i.PayloadLength) - b[IPv6NextHeaderOffset] = i.NextHeader b[hopLimit] = i.HopLimit b.SetSourceAddress(i.SrcAddr) b.SetDestinationAddress(i.DstAddr) + nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr) + b[IPv6NextHeaderOffset] = nextHeader } // IsValid performs basic validation on the packet. @@ -286,7 +296,7 @@ func IsV4MappedAddress(addr tcpip.Address) bool { return false } - return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff") + return IPv4MappedIPv6Subnet.Contains(addr) } // IsV6MulticastAddress determines if the provided address is an IPv6 @@ -392,17 +402,6 @@ func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope } -// IsV6UniqueLocalAddress determines if the provided address is an IPv6 -// unique-local address (within the prefix FC00::/7). -func IsV6UniqueLocalAddress(addr tcpip.Address) bool { - if len(addr) != IPv6AddressSize { - return false - } - // According to RFC 4193 section 3.1, a unique local address has the prefix - // FC00::/7. - return (addr[0] & 0xfe) == 0xfc -} - // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier // (IID) to buf as outlined by RFC 7217 and returns the extended buffer. // @@ -449,9 +448,6 @@ const ( // LinkLocalScope indicates a link-local address. 
LinkLocalScope IPv6AddressScope = iota - // UniqueLocalScope indicates a unique-local address. - UniqueLocalScope - // GlobalScope indicates a global address. GlobalScope ) @@ -469,9 +465,6 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) { case IsV6LinkLocalAddress(addr): return LinkLocalScope, nil - case IsV6UniqueLocalAddress(addr): - return UniqueLocalScope, nil - default: return GlobalScope, nil } diff --git a/pkg/tcpip/header/ipv6_extension_headers.go b/pkg/tcpip/header/ipv6_extension_headers.go index 571eae233..f18981332 100644 --- a/pkg/tcpip/header/ipv6_extension_headers.go +++ b/pkg/tcpip/header/ipv6_extension_headers.go @@ -18,9 +18,12 @@ import ( "bufio" "bytes" "encoding/binary" + "errors" "fmt" "io" + "math" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -75,8 +78,8 @@ const ( // Fragment Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetOffset = 0 - // ipv6FragmentExtHdrFragmentOffsetShift is the least significant bits to - // discard from the Fragment Offset. + // ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment + // Offset field within an IPv6FragmentExtHdr. ipv6FragmentExtHdrFragmentOffsetShift = 3 // ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an @@ -114,6 +117,37 @@ const ( IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8 ) +// padIPv6OptionsLength returns the total length for IPv6 options of length l +// considering the 8-octet alignment as stated in RFC 8200 Section 4.2. +func padIPv6OptionsLength(length int) int { + return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1) +} + +// padIPv6Option fills b with the appropriate padding options depending on its +// length. +func padIPv6Option(b []byte) { + switch len(b) { + case 0: // No padding needed. + case 1: // Pad with Pad1. + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier) + default: // Pad with PadN. 
+ s := b[ipv6ExtHdrOptionPayloadOffset:] + for i := range s { + s[i] = 0 + } + b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier) + b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s)) + } +} + +// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to +// serialize an option at headerOffset with alignment requirements +// [align]n + alignOffset. +func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int { + padLen := headerOffset - alignOffset + return ((padLen + align - 1) & ^(align - 1)) - padLen +} + // IPv6PayloadHeader is implemented by the various headers that can be found // in an IPv6 payload. // @@ -206,29 +240,55 @@ type IPv6ExtHdrOption interface { isIPv6ExtHdrOption() } -// IPv6ExtHdrOptionIndentifier is an IPv6 extension header option identifier. -type IPv6ExtHdrOptionIndentifier uint8 +// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier. +type IPv6ExtHdrOptionIdentifier uint8 const ( // ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that // provides 1 byte padding, as outlined in RFC 8200 section 4.2. - ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 0 + ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0 // ipv6PadBExtHdrOptionIdentifier is the identifier for a padding option that // provides variable length byte padding, as outlined in RFC 8200 section 4.2. - ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIndentifier = 1 + ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1 + + // ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router + // Alert Hop by Hop option as defined in RFC 2711 section 2.1. + ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5 + + // ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header + // option as defined in RFC 8200 section 4.2. 
+ ipv6ExtHdrOptionTypeOffset = 0 + + // ipv6ExtHdrOptionLengthOffset is the option length offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionLengthOffset = 1 + + // ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension + // header option as defined in RFC 8200 section 4.2. + ipv6ExtHdrOptionPayloadOffset = 2 ) +// ipv6UnknownActionFromIdentifier maps an extension header option's +// identifier's high bits to the action to take when the identifier is unknown. +func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction { + return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) +} + +// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option +// is malformed. +var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option") + // IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension // header option that is unknown by the parsing utilities. type IPv6UnknownExtHdrOption struct { - Identifier IPv6ExtHdrOptionIndentifier + Identifier IPv6ExtHdrOptionIdentifier Data []byte } // UnknownAction implements IPv6OptionUnknownAction.UnknownAction. func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction { - return IPv6OptionUnknownAction((o.Identifier & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift) + return ipv6UnknownActionFromIdentifier(o.Identifier) } // isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption. @@ -251,7 +311,7 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error // options buffer has been exhausted and we are done iterating. return nil, true, nil } - id := IPv6ExtHdrOptionIndentifier(temp) + id := IPv6ExtHdrOptionIdentifier(temp) // If the option identifier indicates the option is a Pad1 option, then we // know the option does not have Length and Data fields. 
End processing of @@ -294,6 +354,19 @@ func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error panic(fmt.Sprintf("error when skipping PadN (N = %d) option's data bytes: %s", length, err)) } continue + case ipv6RouterAlertHopByHopOptionIdentifier: + var routerAlertValue [ipv6RouterAlertPayloadLength]byte + if n, err := io.ReadFull(&i.reader, routerAlertValue[:]); err != nil { + switch err { + case io.EOF, io.ErrUnexpectedEOF: + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + default: + return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err) + } + } else if n != int(length) { + return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption) + } + return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil default: bytes := make([]byte, length) if n, err := io.ReadFull(&i.reader, bytes); err != nil { @@ -609,3 +682,248 @@ func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IP return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), bytes, nil } + +// IPv6SerializableExtHdr provides serialization for IPv6 extension +// headers. +type IPv6SerializableExtHdr interface { + // identifier returns the assigned IPv6 header identifier for this extension + // header. + identifier() IPv6ExtensionHeaderIdentifier + + // length returns the total serialized length in bytes of this extension + // header, including the common next header and length fields. + length() int + + // serializeInto serializes the receiver into the provided byte + // buffer and with the provided nextHeader value. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. 
Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto returns the number of bytes that was used to serialize the + // receiver. Implementers must only use the number of bytes required to + // serialize the receiver. Callers MAY provide a larger buffer than required + // to serialize into. + serializeInto(nextHeader uint8, b []byte) int +} + +var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil) + +// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop +// options extension header. +type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption + +const ( + // ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field + // in a hop by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrNextHeaderOffset = 0 + + // ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop + // by hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrLengthOffset = 1 + + // ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by + // hop extension header as defined in RFC 8200 section 4.3. + ipv6HopByHopExtHdrOptionsOffset = 2 + + // ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet + // words in a hop by hop extension header's length field, as stated in RFC + // 8200 section 4.3: + // Length of the Hop-by-Hop Options header in 8-octet units, + // not including the first 8 octets. + ipv6HopByHopExtHdrUnaccountedLenWords = 1 +) + +// identifier implements IPv6SerializableExtHdr. +func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6HopByHopOptionsExtHdrIdentifier +} + +// length implements IPv6SerializableExtHdr. 
+func (h IPv6SerializableHopByHopExtHdr) length() int { + var total int + for _, opt := range h { + align, alignOffset := opt.alignment() + total += ipv6OptionsAlignmentPadding(total, align, alignOffset) + total += ipv6ExtHdrOptionPayloadOffset + int(opt.length()) + } + // Account for next header and total length fields and add padding. + return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total) +} + +// serializeInto implements IPv6SerializableExtHdr. +func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int { + optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:] + totalLength := ipv6HopByHopExtHdrOptionsOffset + for _, opt := range h { + // Calculate alignment requirements and pad buffer if necessary. + align, alignOffset := opt.alignment() + padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset) + if padLen != 0 { + padIPv6Option(optBuffer[:padLen]) + totalLength += padLen + optBuffer = optBuffer[padLen:] + } + + l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:]) + optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier()) + optBuffer[ipv6ExtHdrOptionLengthOffset] = l + l += ipv6ExtHdrOptionPayloadOffset + totalLength += int(l) + optBuffer = optBuffer[l:] + } + padded := padIPv6OptionsLength(totalLength) + if padded != totalLength { + padIPv6Option(optBuffer[:padded-totalLength]) + totalLength = padded + } + wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords + if wordsLen > math.MaxUint8 { + panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen)) + } + b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader + b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen) + return totalLength +} + +// IPv6SerializableHopByHopOption provides serialization for hop by hop options. +type IPv6SerializableHopByHopOption interface { + // identifier returns the option identifier of this Hop by Hop option. 
+ identifier() IPv6ExtHdrOptionIdentifier + + // length returns the *payload* size of the option (not considering the type + // and length fields). + length() uint8 + + // alignment returns the alignment requirements from this option. + // + // Alignment requirements take the form [align]n + offset as specified in + // RFC 8200 section 4.2. The alignment requirement is on the offset between + // the option type byte and the start of the hop by hop header. + // + // align must be a power of 2. + alignment() (align int, offset int) + + // serializeInto serializes the receiver into the provided byte + // buffer. + // + // Note, the caller MUST provide a byte buffer with size of at least + // length. Implementers of this function may assume that the byte buffer + // is of sufficient size. serializeInto MAY panic if the provided byte + // buffer is not of sufficient size. + // + // serializeInto will return the number of bytes that was used to + // serialize the receiver. Implementers must only use the number of + // bytes required to serialize the receiver. Callers MAY provide a + // larger buffer than required to serialize into. + serializeInto([]byte) uint8 +} + +var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil) + +// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in +// RFC 2711 section 2.1. +type IPv6RouterAlertOption struct { + Value IPv6RouterAlertValue +} + +// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option. +type IPv6RouterAlertValue uint16 + +const ( + // IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener + // Discovery message as defined in RFC 2711 section 2.1. + IPv6RouterAlertMLD IPv6RouterAlertValue = 0 + // IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as + // defined in RFC 2711 section 2.1. 
+ IPv6RouterAlertRSVP IPv6RouterAlertValue = 1 + // IPv6RouterAlertActiveNetworks indicates a datagram containing an Active + // Networks message as defined in RFC 2711 section 2.1. + IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2 + + // ipv6RouterAlertPayloadLength is the length of the Router Alert payload + // as defined in RFC 2711. + ipv6RouterAlertPayloadLength = 2 + + // ipv6RouterAlertAlignmentRequirement is the alignment requirement for the + // Router Alert option defined as 2n+0 in RFC 2711. + ipv6RouterAlertAlignmentRequirement = 2 + + // ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset + // requirement for the Router Alert option defined as 2n+0 in RFC 2711 section + // 2.1. + ipv6RouterAlertAlignmentOffsetRequirement = 0 +) + +// UnknownAction implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction { + return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier) +} + +// isIPv6ExtHdrOption implements IPv6ExtHdrOption. +func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {} + +// identifier implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier { + return ipv6RouterAlertHopByHopOptionIdentifier +} + +// length implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) length() uint8 { + return ipv6RouterAlertPayloadLength +} + +// alignment implements IPv6SerializableHopByHopOption. +func (*IPv6RouterAlertOption) alignment() (int, int) { + // From RFC 2711 section 2.1: + // Alignment requirement: 2n+0. + return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 { + binary.BigEndian.PutUint16(b, uint16(o.Value)) + return ipv6RouterAlertPayloadLength +} + +// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers. 
+type IPv6ExtHdrSerializer []IPv6SerializableExtHdr + +// Serialize serializes the provided list of IPv6 extension headers into b. +// +// Note, b must be of sufficient size to hold all the headers in s. See +// IPv6ExtHdrSerializer.Length for details on the getting the total size of a +// serialized IPv6ExtHdrSerializer. +// +// Serialize may panic if b is not of sufficient size to hold all the options +// in s. +// +// Serialize takes the transportProtocol value to be used as the last extension +// header's Next Header value and returns the header identifier of the first +// serialized extension header and the total serialized length. +func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) { + nextHeader := uint8(transportProtocol) + if len(s) == 0 { + return nextHeader, 0 + } + var totalLength int + for i, h := range s[:len(s)-1] { + length := h.serializeInto(uint8(s[i+1].identifier()), b) + b = b[length:] + totalLength += length + } + totalLength += s[len(s)-1].serializeInto(nextHeader, b) + return uint8(s[0].identifier()), totalLength +} + +// Length returns the total number of bytes required to serialize the extension +// headers. 
+func (s IPv6ExtHdrSerializer) Length() int { + var totalLength int + for _, h := range s { + totalLength += h.length() + } + return totalLength +} diff --git a/pkg/tcpip/header/ipv6_extension_headers_test.go b/pkg/tcpip/header/ipv6_extension_headers_test.go index ab20c5f37..65adc6250 100644 --- a/pkg/tcpip/header/ipv6_extension_headers_test.go +++ b/pkg/tcpip/header/ipv6_extension_headers_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -59,7 +60,7 @@ func (a IPv6DestinationOptionsExtHdr) Equal(b IPv6DestinationOptionsExtHdr) bool func TestIPv6UnknownExtHdrOption(t *testing.T) { tests := []struct { name string - identifier IPv6ExtHdrOptionIndentifier + identifier IPv6ExtHdrOptionIdentifier expectedUnknownAction IPv6OptionUnknownAction }{ { @@ -211,6 +212,31 @@ func TestIPv6OptionsExtHdrIterErr(t *testing.T) { bytes: []byte{1, 3}, err: io.ErrUnexpectedEOF, }, + { + name: "Router alert without data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with partial data and Pad1", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1, 1, 0}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with extra data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 3, 1, 2, 3}, + err: ErrMalformedIPv6ExtHdrOption, + }, + { + name: "Router alert with missing data", + bytes: []byte{byte(ipv6RouterAlertHopByHopOptionIdentifier), 1}, + err: io.ErrUnexpectedEOF, + }, } check := func(t *testing.T, it IPv6OptionsExtHdrOptionsIterator, expectedErr error) { @@ -990,3 +1016,331 @@ func TestIPv6ExtHdrIter(t *testing.T) { }) } } + +var _ IPv6SerializableHopByHopOption = (*dummyHbHOptionSerializer)(nil) + +// 
dummyHbHOptionSerializer provides a generic implementation of +// IPv6SerializableHopByHopOption for use in tests. +type dummyHbHOptionSerializer struct { + id IPv6ExtHdrOptionIdentifier + payload []byte + align int + alignOffset int +} + +// identifier implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) identifier() IPv6ExtHdrOptionIdentifier { + return s.id +} + +// length implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) length() uint8 { + return uint8(len(s.payload)) +} + +// alignment implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) alignment() (int, int) { + align := 1 + if s.align != 0 { + align = s.align + } + return align, s.alignOffset +} + +// serializeInto implements IPv6SerializableHopByHopOption. +func (s *dummyHbHOptionSerializer) serializeInto(b []byte) uint8 { + return uint8(copy(b, s.payload)) +} + +func TestIPv6HopByHopSerializer(t *testing.T) { + validateDummies := func(t *testing.T, serializable IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + dummy, ok := serializable.(*dummyHbHOptionSerializer) + if !ok { + t.Fatalf("got serializable = %T, want = *dummyHbHOptionSerializer", serializable) + } + unknown, ok := deserialized.(*IPv6UnknownExtHdrOption) + if !ok { + t.Fatalf("got deserialized = %T, want = %T", deserialized, &IPv6UnknownExtHdrOption{}) + } + if dummy.id != unknown.Identifier { + t.Errorf("got deserialized identifier = %d, want = %d", unknown.Identifier, dummy.id) + } + if diff := cmp.Diff(dummy.payload, unknown.Data); diff != "" { + t.Errorf("option payload deserialization mismatch (-want +got):\n%s", diff) + } + } + tests := []struct { + name string + nextHeader uint8 + options []IPv6SerializableHopByHopOption + expect []byte + validate func(*testing.T, IPv6SerializableHopByHopOption, IPv6ExtHdrOption) + }{ + { + name: "single option", + nextHeader: 13, + options: []IPv6SerializableHopByHopOption{ + 
&dummyHbHOptionSerializer{ + id: 15, + payload: []byte{9, 8, 7, 6}, + }, + }, + expect: []byte{13, 0, 15, 4, 9, 8, 7, 6}, + validate: validateDummies, + }, + { + name: "short option padN zero", + nextHeader: 88, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5}, + }, + }, + expect: []byte{88, 0, 22, 2, 4, 5, 1, 0}, + validate: validateDummies, + }, + { + name: "short option pad1", + nextHeader: 11, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 33, + payload: []byte{1, 2, 3}, + }, + }, + expect: []byte{11, 0, 33, 3, 1, 2, 3, 0}, + validate: validateDummies, + }, + { + name: "long option padN", + nextHeader: 55, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 77, + payload: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + }, + }, + expect: []byte{55, 1, 77, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 22, 3, 4, 5, 6, 1, 2, 0, 0}, + validate: validateDummies, + }, + { + name: "two options align 2n", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2, 3}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 2, + }, + }, + expect: []byte{33, 1, 11, 3, 1, 2, 3, 0, 22, 3, 4, 5, 6, 1, 1, 0}, + validate: validateDummies, + }, + { + name: "two options align 8n+1", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{ + &dummyHbHOptionSerializer{ + id: 11, + payload: []byte{1, 2}, + }, + &dummyHbHOptionSerializer{ + id: 22, + payload: []byte{4, 5, 6}, + align: 8, + alignOffset: 1, + }, + }, + expect: []byte{33, 1, 11, 2, 1, 2, 1, 1, 0, 22, 3, 4, 5, 6, 1, 0}, 
+ validate: validateDummies, + }, + { + name: "no options", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{}, + expect: []byte{33, 0, 1, 4, 0, 0, 0, 0}, + }, + { + name: "Router Alert", + nextHeader: 33, + options: []IPv6SerializableHopByHopOption{&IPv6RouterAlertOption{Value: IPv6RouterAlertMLD}}, + expect: []byte{33, 0, 5, 2, 0, 0, 1, 0}, + validate: func(t *testing.T, _ IPv6SerializableHopByHopOption, deserialized IPv6ExtHdrOption) { + t.Helper() + routerAlert, ok := deserialized.(*IPv6RouterAlertOption) + if !ok { + t.Fatalf("got deserialized = %T, want = *IPv6RouterAlertOption", deserialized) + } + if routerAlert.Value != IPv6RouterAlertMLD { + t.Errorf("got routerAlert.Value = %d, want = %d", routerAlert.Value, IPv6RouterAlertMLD) + } + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6SerializableHopByHopExtHdr(test.options) + length := s.length() + if length != len(test.expect) { + t.Fatalf("got s.length() = %d, want = %d", length, len(test.expect)) + } + b := make([]byte, length) + for i := range b { + // Fill the buffer with ones to ensure all padding is correctly set. + b[i] = 0xFF + } + if got := s.serializeInto(test.nextHeader, b); got != length { + t.Fatalf("got s.serializeInto(..) = %d, want = %d", got, length) + } + if diff := cmp.Diff(test.expect, b); diff != "" { + t.Fatalf("serialization mismatch (-want +got):\n%s", diff) + } + + // Deserialize the options and verify them. 
+ optLen := (b[ipv6HopByHopExtHdrLengthOffset] + ipv6HopByHopExtHdrUnaccountedLenWords) * ipv6ExtHdrLenBytesPerUnit + iter := ipv6OptionsExtHdr(b[ipv6HopByHopExtHdrOptionsOffset:optLen]).Iter() + for _, testOpt := range test.options { + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, false, _)", opt, done) + } + test.validate(t, testOpt, opt) + } + opt, done, err := iter.Next() + if err != nil { + t.Fatalf("iter.Next(): %s", err) + } + if !done { + t.Fatalf("got iter.Next() = (%T, %t, _), want = (_, true, _)", opt, done) + } + }) + } +} + +var _ IPv6SerializableExtHdr = (*dummyIPv6ExtHdrSerializer)(nil) + +// dummyIPv6ExtHdrSerializer provides a generic implementation of +// IPv6SerializableExtHdr for use in tests. +// +// The dummy header always carries the nextHeader value in the first byte. +type dummyIPv6ExtHdrSerializer struct { + id IPv6ExtensionHeaderIdentifier + headerContents []byte +} + +// identifier implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) identifier() IPv6ExtensionHeaderIdentifier { + return s.id +} + +// length implements IPv6SerializableExtHdr. +func (s *dummyIPv6ExtHdrSerializer) length() int { + return len(s.headerContents) + 1 +} + +// serializeInto implements IPv6SerializableExtHdr. 
+func (s *dummyIPv6ExtHdrSerializer) serializeInto(nextHeader uint8, b []byte) int { + b[0] = nextHeader + return copy(b[1:], s.headerContents) + 1 +} + +func TestIPv6ExtHdrSerializer(t *testing.T) { + tests := []struct { + name string + headers []IPv6SerializableExtHdr + nextHeader tcpip.TransportProtocolNumber + expectSerialized []byte + expectNextHeader uint8 + }{ + { + name: "one header", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 15, + headerContents: []byte{1, 2, 3, 4}, + }, + }, + nextHeader: TCPProtocolNumber, + expectSerialized: []byte{byte(TCPProtocolNumber), 1, 2, 3, 4}, + expectNextHeader: 15, + }, + { + name: "two headers", + headers: []IPv6SerializableExtHdr{ + &dummyIPv6ExtHdrSerializer{ + id: 22, + headerContents: []byte{1, 2, 3}, + }, + &dummyIPv6ExtHdrSerializer{ + id: 23, + headerContents: []byte{4, 5, 6}, + }, + }, + nextHeader: ICMPv6ProtocolNumber, + expectSerialized: []byte{ + 23, 1, 2, 3, + byte(ICMPv6ProtocolNumber), 4, 5, 6, + }, + expectNextHeader: 22, + }, + { + name: "no headers", + headers: []IPv6SerializableExtHdr{}, + nextHeader: UDPProtocolNumber, + expectSerialized: []byte{}, + expectNextHeader: byte(UDPProtocolNumber), + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := IPv6ExtHdrSerializer(test.headers) + l := s.Length() + if got, want := l, len(test.expectSerialized); got != want { + t.Fatalf("got serialized length = %d, want = %d", got, want) + } + b := make([]byte, l) + for i := range b { + // Fill the buffer with garbage to make sure we're writing to all bytes. + b[i] = 0xFF + } + nextHeader, serializedLen := s.Serialize(test.nextHeader, b) + if serializedLen != len(test.expectSerialized) || nextHeader != test.expectNextHeader { + t.Errorf( + "got s.Serialize(..) 
= (%d, %d), want = (%d, %d)", + nextHeader, + serializedLen, + test.expectNextHeader, + len(test.expectSerialized), + ) + } + if diff := cmp.Diff(test.expectSerialized, b); diff != "" { + t.Errorf("serialization mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go index 018555a26..9d09f32eb 100644 --- a/pkg/tcpip/header/ipv6_fragment.go +++ b/pkg/tcpip/header/ipv6_fragment.go @@ -27,12 +27,11 @@ const ( idV6 = 4 ) -// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv6FragmentFields struct { - // NextHeader is the "next header" field of an IPv6 fragment. - NextHeader uint8 +var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil) +// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment +// extension header as defined in RFC 8200 section 4.5. +type IPv6SerializableFragmentExtHdr struct { // FragmentOffset is the "fragment offset" field of an IPv6 fragment. FragmentOffset uint16 @@ -43,6 +42,29 @@ type IPv6FragmentFields struct { Identification uint32 } +// identifier implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier { + return IPv6FragmentHeader +} + +// length implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) length() int { + return IPv6FragmentHeaderSize +} + +// serializeInto implements IPv6SerializableFragmentExtHdr. +func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int { + // Prevent too many bounds checks. 
+ _ = b[IPv6FragmentHeaderSize:] + binary.BigEndian.PutUint32(b[idV6:], h.Identification) + binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift) + b[nextHdrFrag] = nextHeader + if h.M { + b[more] |= ipv6FragmentExtHdrMFlagMask + } + return IPv6FragmentHeaderSize +} + // IPv6Fragment represents an ipv6 fragment header stored in a byte array. // Most of the methods of IPv6Fragment access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. @@ -58,16 +80,6 @@ const ( IPv6FragmentHeaderSize = 8 ) -// Encode encodes all the fields of the ipv6 fragment. -func (b IPv6Fragment) Encode(i *IPv6FragmentFields) { - b[nextHdrFrag] = i.NextHeader - binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3) - if i.M { - b[more] |= 1 - } - binary.BigEndian.PutUint32(b[idV6:], i.Identification) -} - // IsValid performs basic validation on the fragment header. func (b IPv6Fragment) IsValid() bool { return len(b) >= IPv6FragmentHeaderSize diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index 426a873b1..e3fbd64f3 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -215,48 +215,6 @@ func TestLinkLocalAddrWithOpaqueIID(t *testing.T) { } } -func TestIsV6UniqueLocalAddress(t *testing.T) { - tests := []struct { - name string - addr tcpip.Address - expected bool - }{ - { - name: "Valid Unique 1", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Valid Unique 2", - addr: uniqueLocalAddr1, - expected: true, - }, - { - name: "Link Local", - addr: linkLocalAddr, - expected: false, - }, - { - name: "Global", - addr: globalAddr, - expected: false, - }, - { - name: "IPv4", - addr: "\x01\x02\x03\x04", - expected: false, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if got := header.IsV6UniqueLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6UniqueLocalAddress(%s) = %t, 
want = %t", test.addr, got, test.expected) - } - }) - } -} - func TestIsV6LinkLocalMulticastAddress(t *testing.T) { tests := []struct { name string @@ -346,7 +304,7 @@ func TestScopeForIPv6Address(t *testing.T) { { name: "Unique Local", addr: uniqueLocalAddr1, - scope: header.UniqueLocalScope, + scope: header.GlobalScope, err: nil, }, { diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 0efbfb22b..d9f8e3b35 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -31,7 +31,7 @@ type PacketInfo struct { Pkt *stack.PacketBuffer Proto tcpip.NetworkProtocolNumber GSO *stack.GSO - Route *stack.Route + Route stack.RouteInfo } // Notification is the interface for receiving notification from the packet @@ -230,15 +230,11 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // WritePacket stores outbound packets into the channel. func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. - route := r.Clone() - route.Release() p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } e.q.Write(p) @@ -248,17 +244,13 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne // WritePackets stores outbound packets into the channel. func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { - // Clone r then release its resource so we only get the relevant fields from - // stack.Route without holding a reference to a NIC's endpoint. 
- route := r.Clone() - route.Release() n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { p := PacketInfo{ Pkt: pkt, Proto: protocol, GSO: gso, - Route: route, + Route: r.GetFields(), } if !e.q.Write(p) { diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index 9f2084eae..cb94cbea6 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -284,9 +284,12 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher } switch sa.(type) { case *unix.SockaddrLinklayer: - // enable PACKET_FANOUT mode is the underlying socket is - // of type AF_PACKET. - const fanoutType = 0x8000 // PACKET_FANOUT_HASH | PACKET_FANOUT_FLAG_DEFRAG + // Enable PACKET_FANOUT mode if the underlying socket is of type + // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will + // prevent gvisor from receiving fragmented packets and the host does the + // reassembly on our behalf before delivering the fragments. This makes it + // hard to test fragmentation reassembly code in Netstack. + const fanoutType = unix.PACKET_FANOUT_HASH fanoutArg := fanoutID | fanoutType<<16 if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index ce4da7230..a87abc6d6 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -323,9 +323,8 @@ func TestPreserveSrcAddress(t *testing.T) { defer c.cleanup() // Set LocalLinkAddress in route to the value of the bridged address. 
- r := &stack.Route{ - LocalLinkAddress: baddr, - } + var r stack.Route + r.LocalLinkAddress = baddr r.ResolveWith(raddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -335,7 +334,7 @@ func TestPreserveSrcAddress(t *testing.T) { ReserveHeaderBytes: header.EthernetMinimumSize, Data: buffer.VectorisedView{}, }) - if err := c.ep.WritePacket(r, nil /* gso */, proto, pkt); err != nil { + if err := c.ep.WritePacket(&r, nil /* gso */, proto, pkt); err != nil { t.Fatalf("WritePacket failed: %v", err) } diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go index 3e4afcdad..b511d3a31 100644 --- a/pkg/tcpip/link/muxed/injectable_test.go +++ b/pkg/tcpip/link/muxed/injectable_test.go @@ -51,7 +51,8 @@ func TestInjectableEndpointDispatch(t *testing.T) { Data: buffer.NewViewFromBytes([]byte{0xFB}).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) @@ -73,7 +74,8 @@ func TestInjectableEndpointDispatchHdrOnly(t *testing.T) { Data: buffer.NewView(0).ToVectorisedView(), }) pkt.TransportHeader().Push(1)[0] = 0xFA - packetRoute := stack.Route{RemoteAddress: dstIP} + var packetRoute stack.Route + packetRoute.RemoteAddress = dstIP endpoint.WritePacket(&packetRoute, nil /* gso */, ipv4.ProtocolNumber, pkt) buf := make([]byte, 6500) bytesRead, err := sock.Read(buf) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 27667f5f0..b7458b620 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -154,8 +154,7 @@ func (e *endpoint) GSOMaxSize() uint32 { func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { // WritePacket caller's do not set the following fields in 
PacketBuffer // so we populate them here. - newRoute := r.Clone() - pkt.EgressRoute = newRoute + pkt.EgressRoute = r pkt.GSOOptions = gso pkt.NetworkProtocolNumber = protocol d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] @@ -178,11 +177,6 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB for pkt := pkts.Front(); pkt != nil; { d := e.dispatchers[int(pkt.Hash)%len(e.dispatchers)] nxt := pkt.Next() - // Since qdisc can hold onto a packet for long we should Clone - // the route here to ensure it doesn't get released while the - // packet is still in our queue. - newRoute := pkt.EgressRoute.Clone() - pkt.EgressRoute = newRoute if !d.q.enqueue(pkt) { if enqueued > 0 { d.newPacketWaker.Assert() diff --git a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go index eb5abb906..45adcbccb 100644 --- a/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go +++ b/pkg/tcpip/link/qdisc/fifo/packet_buffer_queue.go @@ -61,6 +61,7 @@ func (q *packetBufferQueue) enqueue(s *stack.PacketBuffer) bool { q.mu.Lock() r := q.used < q.limit if r { + s.EgressRoute.Acquire() q.list.PushBack(s) q.used++ } diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 7131392cc..dd2e1a125 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -340,9 +340,8 @@ func TestPreserveSrcAddressInSend(t *testing.T) { newLocalLinkAddress := tcpip.LinkAddress(strings.Repeat("0xFE", 6)) // Set both remote and local link address in route. 
- r := stack.Route{ - LocalLinkAddress: newLocalLinkAddress, - } + var r stack.Route + r.LocalLinkAddress = newLocalLinkAddress r.ResolveWith(remoteLinkAddr) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 8d9a91020..1a2cc39eb 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -263,7 +263,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe fragmentOffset = fragOffset case header.ARPProtocolNumber: - if parse.ARP(pkt) { + if !parse.ARP(pkt) { return } diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index a364c5801..bfac358f4 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -264,7 +264,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { // If the packet does not already have link layer header, and the route // does not exist, we can't compute it. This is possibly a raw packet, tun // device doesn't support this at the moment. - if info.Pkt.LinkHeader().View().IsEmpty() && info.Route.RemoteLinkAddress() == "" { + if info.Pkt.LinkHeader().View().IsEmpty() && len(info.Route.RemoteLinkAddress) == 0 { return nil, false } @@ -272,7 +272,7 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. 
if info.Pkt.LinkHeader().View().IsEmpty() { - d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress(), info.Proto, info.Pkt) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } vv.AppendView(info.Pkt.LinkHeader().View()) } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 0fb373612..a25cba513 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -441,9 +441,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +556,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d8e4a3b54..429af69ee 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -18,7 +18,6 @@ go_template_instance( go_library( name = "fragmentation", srcs = [ - "frag_heap.go", "fragmentation.go", "reassembler.go", "reassembler_list.go", @@ -38,7 +37,6 @@ go_test( name = "fragmentation_test", size = "small", srcs = [ - "frag_heap_test.go", "fragmentation_test.go", "reassembler_test.go", ], diff 
--git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go deleted file mode 100644 index 0b570d25a..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package fragmentation - -import ( - "container/heap" - "fmt" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -type fragment struct { - offset uint16 - vv buffer.VectorisedView -} - -type fragHeap []fragment - -func (h *fragHeap) Len() int { - return len(*h) -} - -func (h *fragHeap) Less(i, j int) bool { - return (*h)[i].offset < (*h)[j].offset -} - -func (h *fragHeap) Swap(i, j int) { - (*h)[i], (*h)[j] = (*h)[j], (*h)[i] -} - -func (h *fragHeap) Push(x interface{}) { - *h = append(*h, x.(fragment)) -} - -func (h *fragHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[:n-1] - return x -} - -// reassamble empties the heap and returns a VectorisedView -// containing a reassambled version of the fragments inside the heap. 
-func (h *fragHeap) reassemble() (buffer.VectorisedView, error) { - curr := heap.Pop(h).(fragment) - views := curr.vv.Views() - size := curr.vv.Size() - - if curr.offset != 0 { - return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset) - } - - for h.Len() > 0 { - curr := heap.Pop(h).(fragment) - if int(curr.offset) < size { - curr.vv.TrimFront(size - int(curr.offset)) - } else if int(curr.offset) > size { - return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset) - } - size += curr.vv.Size() - views = append(views, curr.vv.Views()...) - } - return buffer.NewVectorisedView(size, views), nil -} diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go deleted file mode 100644 index 9ececcb9f..000000000 --- a/pkg/tcpip/network/fragmentation/frag_heap_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package fragmentation - -import ( - "container/heap" - "reflect" - "testing" - - "gvisor.dev/gvisor/pkg/tcpip/buffer" -) - -var reassambleTestCases = []struct { - comment string - in []fragment - want buffer.VectorisedView -}{ - { - comment: "Non-overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Non-overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(2, "0", "1"), - }, - { - comment: "Duplicated packets", - in: []fragment{ - {offset: 0, vv: vv(1, "0")}, - {offset: 0, vv: vv(1, "0")}, - }, - want: vv(1, "0"), - }, - { - comment: "Overlapping in-order", - in: []fragment{ - {offset: 0, vv: vv(2, "01")}, - {offset: 1, vv: vv(2, "12")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping out-of-order", - in: []fragment{ - {offset: 1, vv: vv(2, "12")}, - {offset: 0, vv: vv(2, "01")}, - }, - want: vv(3, "01", "2"), - }, - { - comment: "Overlapping subset in-order", - in: []fragment{ - {offset: 0, vv: vv(3, "012")}, - {offset: 1, vv: vv(1, "1")}, - }, - want: vv(3, "012"), - }, - { - comment: "Overlapping subset out-of-order", - in: []fragment{ - {offset: 1, vv: vv(1, "1")}, - {offset: 0, vv: vv(3, "012")}, - }, - want: vv(3, "012"), - }, -} - -func TestReassamble(t *testing.T) { - for _, c := range reassambleTestCases { - t.Run(c.comment, func(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - for _, f := range c.in { - heap.Push(&h, f) - } - got, err := h.reassemble() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, c.want) { - t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) - } - }) - } -} - -func TestReassambleFailsForNonZeroOffset(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when the 
first packet had offset != 0") - } -} - -func TestReassambleFailsForHoles(t *testing.T) { - h := make(fragHeap, 0, 8) - heap.Init(&h) - heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) - heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) - _, err := h.reassemble() - if err == nil { - t.Errorf("reassemble() did not fail when there was a hole in the packet") - } -} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index d31296a41..1af87d713 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -53,6 +53,10 @@ var ( // ErrFragmentOverlap indicates that, during reassembly, a fragment overlaps // with another one. ErrFragmentOverlap = errors.New("overlapping fragments") + + // ErrFragmentConflict indicates that, during reassembly, some fragments are + // in conflict with one another. + ErrFragmentConflict = errors.New("conflicting fragments") ) // FragmentID is the identifier for a fragment. 
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 04072d966..9b20bb1d8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -15,9 +15,8 @@ package fragmentation import ( - "container/heap" - "fmt" "math" + "sort" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -29,6 +28,8 @@ type hole struct { first uint16 last uint16 filled bool + final bool + data buffer.View } type reassembler struct { @@ -39,7 +40,6 @@ type reassembler struct { mu sync.Mutex holes []hole filled int - heap fragHeap done bool creationTime int64 pkt *stack.PacketBuffer @@ -48,51 +48,71 @@ type reassembler struct { func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { r := &reassembler{ id: id, - holes: make([]hole, 0, 16), - heap: make(fragHeap, 0, 8), creationTime: clock.NowMonotonic(), } r.holes = append(r.holes, hole{ first: 0, last: math.MaxUint16, filled: false, + final: true, }) return r } -// updateHoles updates the list of holes for an incoming fragment. It returns -// true if the fragment fits, it is not a duplicate and it does not overlap with -// another fragment. -// -// For IPv6, overlaps with an existing fragment are explicitly forbidden by -// RFC 8200 section 4.5: -// If any of the fragments being reassembled overlap with any other fragments -// being reassembled for the same packet, reassembly of that packet must be -// abandoned and all the fragments that have been received for that packet -// must be discarded, and no ICMP error messages should be sent. 
-// -// It is not explicitly forbidden for IPv4, but to keep parity with Linux we -// disallow it as well: -// https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 -func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { + r.mu.Lock() + defer r.mu.Unlock() + if r.done { + // A concurrent goroutine might have already reassembled + // the packet and emptied the heap while this goroutine + // was waiting on the mutex. We don't have to do anything in this case. + return buffer.VectorisedView{}, 0, false, 0, nil + } + + var holeFound bool + var consumed int for i := range r.holes { currentHole := &r.holes[i] - if currentHole.filled || last < currentHole.first || currentHole.last < first { + if last < currentHole.first || currentHole.last < first { continue } - + // For IPv6, overlaps with an existing fragment are explicitly forbidden by + // RFC 8200 section 4.5: + // If any of the fragments being reassembled overlap with any other + // fragments being reassembled for the same packet, reassembly of that + // packet must be abandoned and all the fragments that have been received + // for that packet must be discarded, and no ICMP error messages should be + // sent. + // + // It is not explicitly forbidden for IPv4, but to keep parity with Linux we + // disallow it as well: + // https://github.com/torvalds/linux/blob/38525c6/net/ipv4/inet_fragment.c#L349 if first < currentHole.first || currentHole.last < last { // Incoming fragment only partially fits in the free hole. - return false, ErrFragmentOverlap + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentOverlap + } + if !more { + if !currentHole.final || currentHole.filled && currentHole.last != last { + // We have another final fragment, which does not perfectly overlap. 
+ return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict + } } - r.filled++ + holeFound = true + if currentHole.filled { + // Incoming fragment is a duplicate. + continue + } + + // We are populating the current hole with the payload and creating a new + // hole for any unfilled ranges on either end. if first > currentHole.first { r.holes = append(r.holes, hole{ first: currentHole.first, last: first - 1, filled: false, + final: false, }) } if last < currentHole.last && more { @@ -100,39 +120,22 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) (bool, error) { first: last + 1, last: currentHole.last, filled: false, + final: currentHole.final, }) + currentHole.final = false } + v := pkt.Data.ToOwnedView() + consumed = v.Size() + r.size += consumed // Update the current hole to precisely match the incoming fragment. r.holes[i] = hole{ first: first, last: last, filled: true, + final: currentHole.final, + data: v, } - return true, nil - } - - // Incoming fragment is a duplicate/subset, or its offset comes after the end - // of the reassembled payload. - return false, nil -} - -func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *stack.PacketBuffer) (buffer.VectorisedView, uint8, bool, int, error) { - r.mu.Lock() - defer r.mu.Unlock() - if r.done { - // A concurrent goroutine might have already reassembled - // the packet and emptied the heap while this goroutine - // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, 0, false, 0, nil - } - - used, err := r.updateHoles(first, last, more) - if err != nil { - return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) - } - - var consumed int - if used { + r.filled++ // For IPv6, it is possible to have different Protocol values between // fragments of a packet (because, unlike IPv4, the Protocol is not used to // identify a fragment). 
In this case, only the Protocol of the first @@ -145,22 +148,30 @@ func (r *reassembler) process(first, last uint16, more bool, proto uint8, pkt *s r.pkt = pkt r.proto = proto } - vv := pkt.Data - // We store the incoming packet only if it filled some holes. - heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)}) - consumed = vv.Size() - r.size += consumed + + break + } + if !holeFound { + // Incoming fragment is beyond end. + return buffer.VectorisedView{}, 0, false, 0, ErrFragmentConflict } // Check if all the holes have been filled and we are ready to reassemble. if r.filled < len(r.holes) { return buffer.VectorisedView{}, 0, false, consumed, nil } - res, err := r.heap.reassemble() - if err != nil { - return buffer.VectorisedView{}, 0, false, 0, fmt.Errorf("fragment reassembly failed: %w", err) + + sort.Slice(r.holes, func(i, j int) bool { + return r.holes[i].first < r.holes[j].first + }) + + var size int + views := make([]buffer.View, 0, len(r.holes)) + for _, hole := range r.holes { + views = append(views, hole.data) + size += hole.data.Size() } - return res, r.proto, true, consumed, nil + return buffer.NewVectorisedView(size, views), r.proto, true, consumed, nil } func (r *reassembler) checkDoneOrMark() bool { diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index cee3063b1..2ff03eeeb 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -19,105 +19,156 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/faketime" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) -type updateHolesParams struct { +type processParams struct { first uint16 last uint16 more bool - wantUsed bool + pkt *stack.PacketBuffer + wantDone bool wantError error } -func TestUpdateHoles(t *testing.T) { +func TestReassemblerProcess(t *testing.T) { + const proto = 99 + + v := func(size 
int) buffer.View { + payload := buffer.NewView(size) + for i := 1; i < size; i++ { + payload[i] = uint8(i) * 3 + } + return payload + } + + pkt := func(size int) *stack.PacketBuffer { + return stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: v(size).ToVectorisedView(), + }) + } + var tests = []struct { name string - params []updateHolesParams + params []processParams want []hole }{ { name: "No fragments", params: nil, - want: []hole{{first: 0, last: math.MaxUint16, filled: false}}, + want: []hole{{first: 0, last: math.MaxUint16, filled: false, final: true}}, }, { name: "One fragment at beginning", - params: []updateHolesParams{{first: 0, last: 1, more: true, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: math.MaxUint16, filled: false}, + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: math.MaxUint16, filled: false, final: true}, }, }, { name: "One fragment in the middle", - params: []updateHolesParams{{first: 1, last: 2, more: true, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true}, - {first: 0, last: 0, filled: false}, - {first: 3, last: math.MaxUint16, filled: false}, + {first: 1, last: 2, filled: true, final: false, data: v(2)}, + {first: 0, last: 0, filled: false, final: false}, + {first: 3, last: math.MaxUint16, filled: false, final: true}, }, }, { name: "One fragment at the end", - params: []updateHolesParams{{first: 1, last: 2, more: false, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 1, last: 2, more: false, pkt: pkt(2), wantDone: false, wantError: nil}}, want: []hole{ - {first: 1, last: 2, filled: true}, + {first: 1, last: 2, filled: true, final: true, data: v(2)}, {first: 0, last: 0, filled: 
false}, }, }, { name: "One fragment completing a packet", - params: []updateHolesParams{{first: 0, last: 1, more: false, wantUsed: true, wantError: nil}}, + params: []processParams{{first: 0, last: 1, more: false, pkt: pkt(2), wantDone: true, wantError: nil}}, want: []hole{ - {first: 0, last: 1, filled: true}, + {first: 0, last: 1, filled: true, final: true, data: v(2)}, }, }, { name: "Two fragments completing a packet", - params: []updateHolesParams{ - {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, - {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: 3, filled: true}, + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, }, }, { name: "Two fragments completing a packet with a duplicate", - params: []updateHolesParams{ - {first: 0, last: 1, more: true, wantUsed: true, wantError: nil}, - {first: 0, last: 1, more: true, wantUsed: false, wantError: nil}, - {first: 2, last: 3, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 0, last: 1, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 2, last: 3, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, + }, + want: []hole{ + {first: 0, last: 1, filled: true, final: false, data: v(2)}, + {first: 2, last: 3, filled: true, final: true, data: v(2)}, + }, + }, + { + name: "Two fragments completing a packet with a partial duplicate", + params: []processParams{ + {first: 0, last: 3, more: true, pkt: pkt(4), wantDone: false, wantError: nil}, + {first: 1, last: 2, more: true, pkt: pkt(2), wantDone: false, wantError: nil}, + {first: 4, 
last: 5, more: false, pkt: pkt(2), wantDone: true, wantError: nil}, }, want: []hole{ - {first: 0, last: 1, filled: true}, - {first: 2, last: 3, filled: true}, + {first: 0, last: 3, filled: true, final: false, data: v(4)}, + {first: 4, last: 5, filled: true, final: true, data: v(2)}, }, }, { name: "Two overlapping fragments", - params: []updateHolesParams{ - {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, - {first: 5, last: 15, more: false, wantUsed: false, wantError: ErrFragmentOverlap}, - {first: 11, last: 15, more: false, wantUsed: true, wantError: nil}, + params: []processParams{ + {first: 0, last: 10, more: true, pkt: pkt(11), wantDone: false, wantError: nil}, + {first: 5, last: 15, more: false, pkt: pkt(11), wantDone: false, wantError: ErrFragmentOverlap}, + }, + want: []hole{ + {first: 0, last: 10, filled: true, final: false, data: v(11)}, + {first: 11, last: math.MaxUint16, filled: false, final: true}, + }, + }, + { + name: "Two final fragments with different ends", + params: []processParams{ + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, + {first: 0, last: 9, more: false, pkt: pkt(10), wantDone: false, wantError: ErrFragmentConflict}, + }, + want: []hole{ + {first: 10, last: 14, filled: true, final: true, data: v(5)}, + {first: 0, last: 9, filled: false, final: false}, + }, + }, + { + name: "Two final fragments - duplicate", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 14, more: false, pkt: pkt(5), wantDone: false, wantError: nil}, }, want: []hole{ - {first: 0, last: 10, filled: true}, - {first: 11, last: 15, filled: true}, + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, }, }, { - name: "Out of bounds fragment", - params: []updateHolesParams{ - {first: 0, last: 10, more: true, wantUsed: true, wantError: nil}, - {first: 11, last: 15, more: false, 
wantUsed: true, wantError: nil}, - {first: 16, last: 20, more: false, wantUsed: false, wantError: nil}, + name: "Two final fragments - duplicate, with different ends", + params: []processParams{ + {first: 5, last: 14, more: false, pkt: pkt(10), wantDone: false, wantError: nil}, + {first: 10, last: 13, more: false, pkt: pkt(4), wantDone: false, wantError: ErrFragmentConflict}, }, want: []hole{ - {first: 0, last: 10, filled: true}, - {first: 11, last: 15, filled: true}, + {first: 5, last: 14, filled: true, final: true, data: v(10)}, + {first: 0, last: 4, filled: false, final: false}, }, }, } @@ -126,9 +177,9 @@ func TestUpdateHoles(t *testing.T) { t.Run(test.name, func(t *testing.T) { r := newReassembler(FragmentID{}, &faketime.NullClock{}) for _, param := range test.params { - used, err := r.updateHoles(param.first, param.last, param.more) - if used != param.wantUsed || err != param.wantError { - t.Errorf("got r.updateHoles(%d, %d, %t) = (%t, %v), want = (%t, %v)", param.first, param.last, param.more, used, err, param.wantUsed, param.wantError) + _, _, done, _, err := r.process(param.first, param.last, param.more, proto, param.pkt) + if done != param.wantDone || err != param.wantError { + t.Errorf("got r.process(%d, %d, %t, %d, _) = (_, _, %t, _, %v), want = (%t, %v)", param.first, param.last, param.more, proto, done, err, param.wantDone, param.wantError) } } if diff := cmp.Diff(test.want, r.holes, cmp.AllowUnexported(hole{})); diff != "" { diff --git a/pkg/tcpip/network/ip/BUILD b/pkg/tcpip/network/ip/BUILD index 6ca200b48..ca1247c1e 100644 --- a/pkg/tcpip/network/ip/BUILD +++ b/pkg/tcpip/network/ip/BUILD @@ -18,6 +18,7 @@ go_test( srcs = ["generic_multicast_protocol_test.go"], deps = [ ":ip", + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/faketime", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol.go b/pkg/tcpip/network/ip/generic_multicast_protocol.go index e308550c4..f2f0e069c 100644 --- 
a/pkg/tcpip/network/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol.go @@ -30,6 +30,23 @@ type hostState int // The states below are generic across IGMPv2 (RFC 2236 section 6) and MLDv1 // (RFC 2710 section 5). Even though the states are generic across both IGMPv2 // and MLDv1, IGMPv2 terminology will be used. +// +// ______________receive query______________ +// | | +// | _____send or receive report_____ | +// | | | | +// V | V | +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ | +// | Non-M | | Pending-M | | Delaying-M | | Queued Delaying-M | | Idle-M | - +// +-------+ +-----------+ +------------+ +-------------------+ +--------+ +// | ^ | ^ | ^ | ^ +// | | | | | | | | +// ---------- ------- ---------- ------------- +// initialize new send inital fail to send send or receive +// group membership report delayed report report +// +// Not shown in the diagram above, but any state may transition into the non +// member state when a group is left. const ( // nonMember is the "'Non-Member' state, when the host does not belong to the // group on the interface. This is the initial state for all memberships on @@ -41,6 +58,15 @@ const ( // but without advertising the membership to the network. nonMember hostState = iota + // pendingMember is a newly joined member that is waiting to successfully send + // the initial set of reports. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the initial report needs to be sent. + // + // MAY NOT transition to the idle member state from this state. + pendingMember + // delayingMember is the "'Delaying Member' state, when the host belongs to // the group on the interface and has a report delay timer running for that // membership." @@ -48,6 +74,16 @@ const ( // 'Delaying Listener' is the MLDv1 term used to describe this state. 
delayingMember + // queuedDelayingMember is a delayingMember that failed to send a report after + // its delayed report timer fired. Hosts in this state are waiting to attempt + // retransmission of the delayed report. + // + // This is not an RFC defined state; it is an implementation specific state to + // track that the delayed report needs to be sent. + // + // May transition to idle member if a report is received for a group. + queuedDelayingMember + // idleMember is the "Idle Member" state, when the host belongs to the group // on the interface and does not have a report delay timer running for that // membership. @@ -56,6 +92,17 @@ const ( idleMember ) +func (s hostState) isDelayingMember() bool { + switch s { + case nonMember, pendingMember, idleMember: + return false + case delayingMember, queuedDelayingMember: + return true + default: + panic(fmt.Sprintf("unrecognized host state = %d", s)) + } +} + // multicastGroupState holds the Generic Multicast Protocol state for a // multicast group. type multicastGroupState struct { @@ -84,17 +131,6 @@ type multicastGroupState struct { // GenericMulticastProtocolOptions holds options for the generic multicast // protocol. type GenericMulticastProtocolOptions struct { - // Enabled indicates whether the generic multicast protocol will be - // performed. - // - // When enabled, the protocol may transmit report and leave messages when - // joining and leaving multicast groups respectively, and handle incoming - // packets. - // - // When disabled, the protocol will still keep track of locally joined groups, - // it just won't transmit and handle packets, or update groups' state. - Enabled bool - // Rand is the source of random numbers. Rand *rand.Rand @@ -123,8 +159,22 @@ type GenericMulticastProtocolOptions struct { // MulticastGroupProtocol is a multicast group protocol whose core state machine // can be represented by GenericMulticastProtocolState. 
type MulticastGroupProtocol interface { + // Enabled indicates whether the generic multicast protocol will be + // performed. + // + // When enabled, the protocol may transmit report and leave messages when + // joining and leaving multicast groups respectively, and handle incoming + // packets. + // + // When disabled, the protocol will still keep track of locally joined groups, + // it just won't transmit and handle packets, or update groups' state. + Enabled() bool + // SendReport sends a multicast report for the specified group address. - SendReport(groupAddress tcpip.Address) *tcpip.Error + // + // Returns false if the caller should queue the report to be sent later. Note, + // returning false does not mean that the receiver hit an error. + SendReport(groupAddress tcpip.Address) (sent bool, err *tcpip.Error) // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) *tcpip.Error @@ -138,76 +188,119 @@ type MulticastGroupProtocol interface { // IPv4 and IPv6. Specifically, Generic Multicast Protocol is the core state // machine of IGMPv2 as defined by RFC 2236 and MLDv1 as defined by RFC 2710. // +// Callers must synchronize accesses to the generic multicast protocol state; +// GenericMulticastProtocolState obtains no locks in any of its methods. The +// only exception to this is GenericMulticastProtocolState's timer/job callbacks +// which will obtain the lock provided to the GenericMulticastProtocolState when +// it is initialized. +// // GenericMulticastProtocolState.Init MUST be called before calling any of // the methods on GenericMulticastProtocolState. +// +// GenericMulticastProtocolState.MakeAllNonMemberLocked MUST be called when the +// multicast group protocol is disabled so that leave messages may be sent. type GenericMulticastProtocolState struct { + // Do not allow overwriting this state. 
+ _ sync.NoCopy + opts GenericMulticastProtocolOptions - mu struct { - sync.RWMutex + // memberships holds group addresses and their associated state. + memberships map[tcpip.Address]multicastGroupState - // memberships holds group addresses and their associated state. - memberships map[tcpip.Address]multicastGroupState - } + // protocolMU is the mutex used to protect the protocol. + protocolMU *sync.RWMutex } // Init initializes the Generic Multicast Protocol state. -func (g *GenericMulticastProtocolState) Init(opts GenericMulticastProtocolOptions) { - g.mu.Lock() - defer g.mu.Unlock() - g.opts = opts - g.mu.memberships = make(map[tcpip.Address]multicastGroupState) +// +// Must only be called once for the lifetime of g; Init will panic if it is +// called twice. +// +// The GenericMulticastProtocolState will only grab the lock when timers/jobs +// fire. +// +// Note: the methods on opts.Protocol will always be called while protocolMU is +// held. +func (g *GenericMulticastProtocolState) Init(protocolMU *sync.RWMutex, opts GenericMulticastProtocolOptions) { + if g.memberships != nil { + panic("attempted to initialize generic membership protocol state twice") + } + + *g = GenericMulticastProtocolState{ + opts: opts, + memberships: make(map[tcpip.Address]multicastGroupState), + protocolMU: protocolMU, + } } -// MakeAllNonMember transitions all groups to the non-member state. +// MakeAllNonMemberLocked transitions all groups to the non-member state. // // The groups will still be considered joined locally. -func (g *GenericMulticastProtocolState) MakeAllNonMember() { - if !g.opts.Enabled { +// +// MUST be called when the multicast group protocol is disabled. +// +// Precondition: g.protocolMU must be locked. 
+func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.transitionToNonMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// InitializeGroups initializes each group, as if they were newly joined but -// without affecting the groups' join count. +// InitializeGroupsLocked initializes each group, as if they were newly joined +// but without affecting the groups' join count. // // Must only be called after calling MakeAllNonMember as a group should not be // initialized while it is not in the non-member state. -func (g *GenericMulticastProtocolState) InitializeGroups() { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) InitializeGroupsLocked() { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - - for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.initializeNewMemberLocked(groupAddress, &info) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// JoinGroup handles joining a new group. +// SendQueuedReportsLocked attempts to send reports for groups that failed to +// send reports during their last attempt. // -// If dontInitialize is true, the group will be not be initialized and will be -// left in the non-member state - no packets will be sent for it until it is -// initialized via InitializeGroups. -func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, dontInitialize bool) { - g.mu.Lock() - defer g.mu.Unlock() +// Precondition: g.protocolMU must be locked. 
+func (g *GenericMulticastProtocolState) SendQueuedReportsLocked() { + for groupAddress, info := range g.memberships { + switch info.state { + case nonMember, delayingMember, idleMember: + case pendingMember: + // pendingMembers failed to send their initial unsolicited report so try + // to send the report and queue the extra unsolicited reports. + g.maybeSendInitialReportLocked(groupAddress, &info) + case queuedDelayingMember: + // queuedDelayingMembers failed to send their delayed reports so try to + // send the report and transition them to the idle state. + g.maybeSendDelayedReportLocked(groupAddress, &info) + default: + panic(fmt.Sprintf("unrecognized host state = %d", info.state)) + } + g.memberships[groupAddress] = info + } +} - if info, ok := g.mu.memberships[groupAddress]; ok { +// JoinGroupLocked handles joining a new group. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) JoinGroupLocked(groupAddress tcpip.Address) { + if info, ok := g.memberships[groupAddress]; ok { // The group has already been joined. info.joins++ - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return } @@ -217,41 +310,43 @@ func (g *GenericMulticastProtocolState) JoinGroup(groupAddress tcpip.Address, do // The state will be updated below, if required. 
state: nonMember, lastToSendReport: false, - delayedReportJob: tcpip.NewJob(g.opts.Clock, &g.mu, func() { - info, ok := g.mu.memberships[groupAddress] + delayedReportJob: tcpip.NewJob(g.opts.Clock, g.protocolMU, func() { + if !g.opts.Protocol.Enabled() { + panic(fmt.Sprintf("delayed report job fired for group %s while the multicast group protocol is disabled", groupAddress)) + } + + info, ok := g.memberships[groupAddress] if !ok { panic(fmt.Sprintf("expected to find group state for group = %s", groupAddress)) } - info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - info.state = idleMember - g.mu.memberships[groupAddress] = info + g.maybeSendDelayedReportLocked(groupAddress, &info) + g.memberships[groupAddress] = info }), } - if !dontInitialize && g.opts.Enabled { + if g.opts.Protocol.Enabled() { g.initializeNewMemberLocked(groupAddress, &info) } - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } -// IsLocallyJoined returns true if the group is locally joined. -func (g *GenericMulticastProtocolState) IsLocallyJoined(groupAddress tcpip.Address) bool { - g.mu.RLock() - defer g.mu.RUnlock() - _, ok := g.mu.memberships[groupAddress] +// IsLocallyJoinedRLocked returns true if the group is locally joined. +// +// Precondition: g.protocolMU must be read locked. +func (g *GenericMulticastProtocolState) IsLocallyJoinedRLocked(groupAddress tcpip.Address) bool { + _, ok := g.memberships[groupAddress] return ok } -// LeaveGroup handles leaving the group. +// LeaveGroupLocked handles leaving the group. // // Returns false if the group is not currently joined. -func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) bool { - g.mu.Lock() - defer g.mu.Unlock() - - info, ok := g.mu.memberships[groupAddress] +// +// Precondition: g.protocolMU must be locked. 
+func (g *GenericMulticastProtocolState) LeaveGroupLocked(groupAddress tcpip.Address) bool { + info, ok := g.memberships[groupAddress] if !ok { return false } @@ -262,30 +357,30 @@ func (g *GenericMulticastProtocolState) LeaveGroup(groupAddress tcpip.Address) b info.joins-- if info.joins != 0 { // If we still have outstanding joins, then do nothing further. - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info return true } g.transitionToNonMemberLocked(groupAddress, &info) - delete(g.mu.memberships, groupAddress) + delete(g.memberships, groupAddress) return true } -// HandleQuery handles a query message with the specified maximum response time. +// HandleQueryLocked handles a query message with the specified maximum response +// time. // // If the group address is unspecified, then reports will be scheduled for all // joined groups. // // Report(s) will be scheduled to be sent after a random duration between 0 and // the maximum response time. -func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, maxResponseTime time.Duration) { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleQueryLocked(groupAddress tcpip.Address, maxResponseTime time.Duration) { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 2.4 (for IGMPv2), // // In a Membership Query message, the group address field is set to zero @@ -299,28 +394,27 @@ func (g *GenericMulticastProtocolState) HandleQuery(groupAddress tcpip.Address, // when sending a Multicast-Address-Specific Query. if groupAddress.Unspecified() { // This is a general query as the group address is unspecified. 
- for groupAddress, info := range g.mu.memberships { + for groupAddress, info := range g.memberships { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } - } else if info, ok := g.mu.memberships[groupAddress]; ok { + } else if info, ok := g.memberships[groupAddress]; ok { g.setDelayTimerForAddressRLocked(groupAddress, &info, maxResponseTime) - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } -// HandleReport handles a report message. +// HandleReportLocked handles a report message. // // If the report is for a joined group, any active delayed report will be // cancelled and the host state for the group transitions to idle. -func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) { - if !g.opts.Enabled { +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) HandleReportLocked(groupAddress tcpip.Address) { + if !g.opts.Protocol.Enabled() { return } - g.mu.Lock() - defer g.mu.Unlock() - // As per RFC 2236 section 3 pages 3-4 (for IGMPv2), // // If the host receives another host's Report (version 1 or 2) while it has @@ -333,23 +427,23 @@ func (g *GenericMulticastProtocolState) HandleReport(groupAddress tcpip.Address) // multicast address while it has a timer running for that same address // on that interface, it stops its timer and does not send a Report for // that address, thus suppressing duplicate reports on the link. - if info, ok := g.mu.memberships[groupAddress]; ok && info.state == delayingMember { + if info, ok := g.memberships[groupAddress]; ok && info.state.isDelayingMember() { info.delayedReportJob.Cancel() info.lastToSendReport = false info.state = idleMember - g.mu.memberships[groupAddress] = info + g.memberships[groupAddress] = info } } // initializeNewMemberLocked initializes a new group membership. // -// Precondition: g.mu must be locked. 
+// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state != nonMember { - panic(fmt.Sprintf("state for group %s is not non-member; state = %d", groupAddress, info.state)) + panic(fmt.Sprintf("host must be in non-member state to be initialized; group = %s, state = %d", groupAddress, info.state)) } - info.state = idleMember + info.lastToSendReport = false if groupAddress == g.opts.AllNodesAddress { // As per RFC 2236 section 6 page 10 (for IGMPv2), @@ -365,9 +459,25 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // case. The node starts in Idle Listener state for that address on // every interface, never transitions to another state, and never sends // a Report or Done for that address. + info.state = idleMember return } + info.state = pendingMember + g.maybeSendInitialReportLocked(groupAddress, info) +} + +// maybeSendInitialReportLocked attempts to start transmission of the initial +// set of reports after newly joining a group. +// +// Host must be in pending member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendInitialReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if info.state != pendingMember { + panic(fmt.Sprintf("host must be in pending member state to send initial reports; group = %s, state = %d", groupAddress, info.state)) + } + // As per RFC 2236 section 3 page 5 (for IGMPv2), // // When a host joins a multicast group, it should immediately transmit an @@ -385,13 +495,35 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t // // TODO(gvisor.dev/issue/4901): Support a configurable number of initial // unsolicited reports. 
- info.lastToSendReport = g.opts.Protocol.SendReport(groupAddress) == nil - g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + g.setDelayTimerForAddressRLocked(groupAddress, info, g.opts.MaxUnsolicitedReportDelay) + } +} + +// maybeSendDelayedReportLocked attempts to send the delayed report. +// +// Host must be in pending, delaying or queued delaying member state. +// +// Precondition: g.protocolMU must be locked. +func (g *GenericMulticastProtocolState) maybeSendDelayedReportLocked(groupAddress tcpip.Address, info *multicastGroupState) { + if !info.state.isDelayingMember() { + panic(fmt.Sprintf("host must be in delaying or queued delaying member state to send delayed reports; group = %s, state = %d", groupAddress, info.state)) + } + + sent, err := g.opts.Protocol.SendReport(groupAddress) + if err == nil && sent { + info.lastToSendReport = true + info.state = idleMember + } else { + info.state = queuedDelayingMember + } } // maybeSendLeave attempts to send a leave message. func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Address, lastToSendReport bool) { - if !g.opts.Enabled || !lastToSendReport { + if !g.opts.Protocol.Enabled() || !lastToSendReport { return } @@ -465,7 +597,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres // transitionToNonMemberLocked transitions the given multicast group the the // non-member/listener state. // -// Precondition: e.mu must be locked. +// Precondition: g.protocolMU must be locked. func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress tcpip.Address, info *multicastGroupState) { if info.state == nonMember { return @@ -479,7 +611,7 @@ func (g *GenericMulticastProtocolState) transitionToNonMemberLocked(groupAddress // setDelayTimerForAddressRLocked sets timer to send a delay report. 
// -// Precondition: g.mu MUST be read locked. +// Precondition: g.protocolMU MUST be read locked. func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddress tcpip.Address, info *multicastGroupState, maxResponseTime time.Duration) { if info.state == nonMember { return @@ -517,6 +649,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr // TODO: Reset the timer if time remaining is greater than maxResponseTime. return } + info.state = delayingMember info.delayedReportJob.Cancel() info.delayedReportJob.Schedule(g.calculateDelayTimerDuration(maxResponseTime)) diff --git a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go index 670be30d4..85593f211 100644 --- a/pkg/tcpip/network/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/ip/generic_multicast_protocol_test.go @@ -20,6 +20,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/network/ip" @@ -36,42 +37,178 @@ const ( var _ ip.MulticastGroupProtocol = (*mockMulticastGroupProtocol)(nil) -type mockMulticastGroupProtocol struct { +type mockMulticastGroupProtocolProtectedFields struct { + sync.RWMutex + + genericMulticastGroup ip.GenericMulticastProtocolState sendReportGroupAddrCount map[tcpip.Address]int sendLeaveGroupAddrCount map[tcpip.Address]int + makeQueuePackets bool + disabled bool } -func (m *mockMulticastGroupProtocol) init() { - m.sendReportGroupAddrCount = make(map[tcpip.Address]int) - m.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +type mockMulticastGroupProtocol struct { + t *testing.T + + mu mockMulticastGroupProtocolProtectedFields } -func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) *tcpip.Error { - m.sendReportGroupAddrCount[groupAddress]++ - return nil +func (m *mockMulticastGroupProtocol) init(opts 
ip.GenericMulticastProtocolOptions) { + m.mu.Lock() + defer m.mu.Unlock() + m.initLocked() + opts.Protocol = m + m.mu.genericMulticastGroup.Init(&m.mu.RWMutex, opts) +} + +func (m *mockMulticastGroupProtocol) initLocked() { + m.mu.sendReportGroupAddrCount = make(map[tcpip.Address]int) + m.mu.sendLeaveGroupAddrCount = make(map[tcpip.Address]int) +} + +func (m *mockMulticastGroupProtocol) setEnabled(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.disabled = !v +} + +func (m *mockMulticastGroupProtocol) setQueuePackets(v bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.makeQueuePackets = v } +func (m *mockMulticastGroupProtocol) joinGroup(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.JoinGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) leaveGroup(addr tcpip.Address) bool { + m.mu.Lock() + defer m.mu.Unlock() + return m.mu.genericMulticastGroup.LeaveGroupLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleReport(addr tcpip.Address) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleReportLocked(addr) +} + +func (m *mockMulticastGroupProtocol) handleQuery(addr tcpip.Address, maxRespTime time.Duration) { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.HandleQueryLocked(addr, maxRespTime) +} + +func (m *mockMulticastGroupProtocol) isLocallyJoined(addr tcpip.Address) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return m.mu.genericMulticastGroup.IsLocallyJoinedRLocked(addr) +} + +func (m *mockMulticastGroupProtocol) makeAllNonMember() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.MakeAllNonMemberLocked() +} + +func (m *mockMulticastGroupProtocol) initializeGroups() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.InitializeGroupsLocked() +} + +func (m *mockMulticastGroupProtocol) sendQueuedReports() { + m.mu.Lock() + defer m.mu.Unlock() + m.mu.genericMulticastGroup.SendQueuedReportsLocked() +} + +// Enabled implements 
ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be read locked. +func (m *mockMulticastGroupProtocol) Enabled() bool { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatal("got write lock, expected to not take the lock; generic multicast protocol must take the read or write lock before calling Enabled") + } + + return !m.mu.disabled +} + +// SendReport implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. +func (m *mockMulticastGroupProtocol) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending report for %s", groupAddress) + } + + m.mu.sendReportGroupAddrCount[groupAddress]++ + return !m.mu.makeQueuePackets, nil +} + +// SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: m.mu must be locked. 
func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - m.sendLeaveGroupAddrCount[groupAddress]++ + if m.mu.TryLock() { + m.mu.Unlock() + m.t.Fatalf("got write lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + if m.mu.TryRLock() { + m.mu.RUnlock() + m.t.Fatalf("got read lock, expected to not take the lock; generic multicast protocol must take the write lock before sending leave for %s", groupAddress) + } + + m.mu.sendLeaveGroupAddrCount[groupAddress]++ return nil } -func checkProtocol(mgp *mockMulticastGroupProtocol, sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { - sendReportGroupAddressesMap := make(map[tcpip.Address]int) +func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { + m.mu.Lock() + defer m.mu.Unlock() + + sendReportGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendReportGroupAddresses { - sendReportGroupAddressesMap[a] = 1 + sendReportGroupAddrCount[a] = 1 } - sendLeaveGroupAddressesMap := make(map[tcpip.Address]int) + sendLeaveGroupAddrCount := make(map[tcpip.Address]int) for _, a := range sendLeaveGroupAddresses { - sendLeaveGroupAddressesMap[a] = 1 + sendLeaveGroupAddrCount[a] = 1 } - diff := cmp.Diff(mockMulticastGroupProtocol{ - sendReportGroupAddrCount: sendReportGroupAddressesMap, - sendLeaveGroupAddrCount: sendLeaveGroupAddressesMap, - }, *mgp, cmp.AllowUnexported(mockMulticastGroupProtocol{})) - mgp.init() + diff := cmp.Diff( + &mockMulticastGroupProtocol{ + mu: mockMulticastGroupProtocolProtectedFields{ + sendReportGroupAddrCount: sendReportGroupAddrCount, + sendLeaveGroupAddrCount: sendLeaveGroupAddrCount, + }, + }, + m, + cmp.AllowUnexported(mockMulticastGroupProtocol{}), + cmp.AllowUnexported(mockMulticastGroupProtocolProtectedFields{}), + // ignore 
mockMulticastGroupProtocol.mu and mockMulticastGroupProtocol.t + cmp.FilterPath( + func(p cmp.Path) bool { + switch p.Last().String() { + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + return true + } + return false + }, + cmp.Ignore(), + ), + ) + m.initLocked() return diff } @@ -95,36 +232,34 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after // a random interval between 0 and the maximum unsolicited report delay. - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.joinGroup(test.addr) if test.shouldSendReports { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. 
clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -151,40 +286,42 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr2, }) - g.JoinGroup(test.addr, false /* dontInitialize */) + mgp.joinGroup(test.addr) if test.shouldSendMessages { - if diff := checkProtocol(&mgp, []tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{test.addr} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Leaving a group should send a leave report immediately and cancel any // delayed reports. 
- if !g.LeaveGroup(test.addr) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", test.addr) + { + + if !mgp.leaveGroup(test.addr) { + t.Fatalf("got mgp.leaveGroup(%s) = false, want = true", test.addr) + } } if test.shouldSendMessages { - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{test.addr} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -226,45 +363,43 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", 
diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a report for a group we have a timer scheduled for should // cancel our delayed report timer for the group. - g.HandleReport(test.reportAddr) + mgp.handleReport(test.reportAddr) if len(test.expectReportsFor) != 0 { + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. 
clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -312,49 +447,47 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { 
t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Receiving a query should make us schedule a new delayed report if it // is a query directed at us or a general query. - g.HandleQuery(test.queryAddr, test.maxDelay) + mgp.handleQuery(test.queryAddr, test.maxDelay) if len(test.expectReportsFor) != 0 { clock.Advance(test.maxDelay) - if diff := checkProtocol(&mgp, test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(test.expectReportsFor /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } // Should have no more messages to send. 
clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } }) @@ -362,133 +495,139 @@ func TestHandleQuery(t *testing.T) { } func TestJoinCount(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(4)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: time.Second, }) // Set the join count to 2 for a group. - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // Only the first join should trigger a report to be sent. 
- if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr1, false /* dontInitialize */) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Group should still be considered joined after leaving once. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) + if !mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) } // A leave report should only be sent once the join count reaches 0. 
- if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Leaving once more should actually remove us from the group. - if !g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = false, want = true", addr1) + if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr1) } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1} /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + if t.Failed() { + t.FailNow() } // Group should no longer be joined so we should not have anything to // leave. 
- if g.LeaveGroup(addr1) { - t.Fatalf("got g.LeaveGroup(%s) = true, want = false", addr1) + if mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = true, want = false", addr1) } - if g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr1) + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr1) } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. + // + // Generic multicast protocol timers are expected to take the job mutex. clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() + mgp := mockMulticastGroupProtocol{t: t} clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: true, + + mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, - Protocol: &mgp, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, AllNodesAddress: addr3, }) - g.JoinGroup(addr1, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* 
sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr2, false /* dontInitialize */) - if diff := checkProtocol(&mgp, []tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - g.JoinGroup(addr3, false /* dontInitialize */) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.joinGroup(addr3) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should send the leave reports for each but still consider them locally // joined. - g.MakeAllNonMember() - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { + mgp.makeAllNonMember() + if diff := mgp.check(nil /* sendReportGroupAddresses */, []tcpip.Address{addr1, addr2} /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } + // Generic multicast protocol timers are expected to take the job mutex. 
clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } for _, group := range []tcpip.Address{addr1, addr2, addr3} { - if !g.IsLocallyJoined(group) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", group) + if !mgp.isLocallyJoined(group) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", group) } } // Should send the initial set of unsolcited reports. - g.InitializeGroups() - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + mgp.initializeGroups() + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } clock.Advance(maxUnsolicitedReportDelay) - if diff := checkProtocol(&mgp, []tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check([]tcpip.Address{addr1, addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } // Should have no more messages to send. 
clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } @@ -496,81 +635,172 @@ func TestMakeAllNonMemberAndInitialize(t *testing.T) { // TestGroupStateNonMember tests that groups do not send packets when in the // non-member state, but are still considered locally joined. func TestGroupStateNonMember(t *testing.T) { - tests := []struct { - name string - enabled bool - dontInitialize bool - }{ - { - name: "Disabled", - enabled: false, - dontInitialize: false, - }, - { - name: "Keep non-member", - enabled: true, - dontInitialize: true, - }, - { - name: "disabled and Keep non-member", - enabled: false, - dontInitialize: true, - }, + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) + mgp.setEnabled(false) + + // Joining groups should not send any reports. 
+ mgp.joinGroup(addr1) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr1) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.joinGroup(addr2) + if !mgp.isLocallyJoined(addr1) { + t.Fatalf("got mgp.isLocallyJoined(%s) = false, want = true", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var g ip.GenericMulticastProtocolState - var mgp mockMulticastGroupProtocol - mgp.init() - clock := faketime.NewManualClock() - g.Init(ip.GenericMulticastProtocolOptions{ - Enabled: test.enabled, - Rand: rand.New(rand.NewSource(3)), - Clock: clock, - Protocol: &mgp, - MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - }) + // Receiving a query should not send any reports. + mgp.handleQuery(addr1, time.Nanosecond) + // Generic multicast protocol timers are expected to take the job mutex. + clock.Advance(time.Nanosecond) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - g.JoinGroup(addr1, test.dontInitialize) - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + // Leaving groups should not send any leave messages. 
+ if !mgp.leaveGroup(addr1) { + t.Errorf("got mgp.leaveGroup(%s) = false, want = true", addr2) + } + if mgp.isLocallyJoined(addr1) { + t.Errorf("got mgp.isLocallyJoined(%s) = true, want = false", addr2) + } + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - g.JoinGroup(addr2, test.dontInitialize) - if !g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr2) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } +} - g.HandleQuery(addr1, time.Nanosecond) - clock.Advance(time.Nanosecond) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } +func TestQueuedPackets(t *testing.T) { + clock := faketime.NewManualClock() + mgp := mockMulticastGroupProtocol{t: t} + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(4)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }) - if !g.LeaveGroup(addr2) { - t.Errorf("got g.LeaveGroup(%s) = false, want = true", addr2) - } - if !g.IsLocallyJoined(addr1) { - t.Fatalf("got g.IsLocallyJoined(%s) = false, want = true", addr1) - } - if g.IsLocallyJoined(addr2) { - t.Fatalf("got g.IsLocallyJoined(%s) = true, want = false", addr2) - } - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - 
} + // Joining should trigger a SendReport, but mgp should report that we did not + // send the packet. + mgp.setQueuePackets(true) + mgp.joinGroup(addr1) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } - clock.Advance(time.Hour) - if diff := checkProtocol(&mgp, nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { - t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) - } - }) + // The delayed report timer should have been cancelled since we did not send + // the initial report earlier. + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send (we should be idle). + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query but mock being unable to send reports again. 
+ mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Mock being able to send reports again - we should have a packet queued to + // send. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receive a query again, but mock being unable to send reports. + mgp.setQueuePackets(true) + mgp.handleQuery(addr1, time.Nanosecond) + clock.Advance(time.Nanosecond) + if diff := mgp.check([]tcpip.Address{addr1} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Receiving a report should transition us into the idle member state, + // even if we had a packet queued. We should no longer have any packets to + // send. + mgp.handleReport(addr1) + mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // When we fail to send the initial set of reports, incoming reports should + // not affect a newly joined group's reports from being sent. 
+ mgp.setQueuePackets(true) + mgp.joinGroup(addr2) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + mgp.handleReport(addr2) + // Attempting to send queued reports while still unable to send reports should + // not change the host state. + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // Mock being able to successfully send the report. + mgp.setQueuePackets(false) + mgp.sendQueuedReports() + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + // The delayed report (sent after the initial report) should now be sent. + clock.Advance(maxUnsolicitedReportDelay) + if diff := mgp.check([]tcpip.Address{addr2} /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Should not have anything else to send. 
+ mgp.sendQueuedReports() + clock.Advance(time.Hour) + if diff := mgp.check(nil /* sendReportGroupAddresses */, nil /* sendLeaveGroupAddresses */); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) } } diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a314dd386..3005973d7 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -344,11 +344,11 @@ func TestSourceAddressValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, localIPv6Addr, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: localIPv6Addr, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: localIPv6Addr, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -619,11 +619,11 @@ func TestReceive(t *testing.T) { view := buffer.NewView(header.IPv6MinimumSize + payloadLen) ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: payloadLen, - NextHeader: 10, - HopLimit: ipv6.DefaultTTL, - SrcAddr: remoteIPv6Addr, - DstAddr: localIPv6Addr, + PayloadLength: payloadLen, + TransportProtocol: 10, + HopLimit: ipv6.DefaultTTL, + SrcAddr: remoteIPv6Addr, + DstAddr: localIPv6Addr, }) // Make payload be non-zero. @@ -993,11 +993,11 @@ func TestIPv6ReceiveControl(t *testing.T) { // Create the outer IPv6 header. 
ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 20, - SrcAddr: outerSrcAddr, - DstAddr: localIPv6Addr, + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 20, + SrcAddr: outerSrcAddr, + DstAddr: localIPv6Addr, }) // Create the ICMP header. @@ -1007,28 +1007,27 @@ func TestIPv6ReceiveControl(t *testing.T) { icmp.SetIdent(0xdead) icmp.SetSequence(0xbeef) - // Create the inner IPv6 header. - ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) - ip.Encode(&header.IPv6Fields{ - PayloadLength: 100, - NextHeader: 10, - HopLimit: 20, - SrcAddr: localIPv6Addr, - DstAddr: remoteIPv6Addr, - }) - + var extHdrs header.IPv6ExtHdrSerializer // Build the fragmentation header if needed. if c.fragmentOffset != nil { - ip.SetNextHeader(header.IPv6FragmentHeader) - frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:]) - frag.Encode(&header.IPv6FragmentFields{ - NextHeader: 10, + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{ FragmentOffset: *c.fragmentOffset, M: true, Identification: 0x12345678, }) } + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: localIPv6Addr, + DstAddr: remoteIPv6Addr, + ExtensionHeaders: extHdrs, + }) + // Make payload be non-zero. 
for i := dataOffset; i < len(view); i++ { view[i] = uint8(i) @@ -1344,10 +1343,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1387,10 +1386,12 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { } ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + // NB: we're lying about transport protocol here to verify the raw + // fragment header bytes. + TransportProtocol: tcpip.TransportProtocolNumber(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return hdr.View().ToVectorisedView() }, @@ -1422,10 +1423,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, }) return buffer.View(ip).ToVectorisedView() }, @@ -1457,10 +1458,10 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - NextHeader: transportProto, - HopLimit: ipv6.DefaultTTL, - SrcAddr: src, - DstAddr: header.IPv4Any, + TransportProtocol: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: 
header.IPv4Any, }) return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index 0134fadc0..da88d65d1 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -16,7 +16,6 @@ package ipv4 import ( "fmt" - "sync" "sync/atomic" "time" @@ -58,6 +57,9 @@ type IGMPOptions struct { // When enabled, IGMP may transmit IGMP report and leave messages when // joining and leaving multicast groups respectively, and handle incoming // IGMP packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). Enabled bool } @@ -68,8 +70,9 @@ var _ ip.MulticastGroupProtocol = (*igmpState)(nil) // igmpState.init() MUST be called after creating an IGMP state. type igmpState struct { // The IPv4 endpoint this igmpState is for. - ep *endpoint - opts IGMPOptions + ep *endpoint + + genericMulticastProtocol ip.GenericMulticastProtocolState // igmpV1Present is for maintaining compatibility with IGMPv1 Routers, from // RFC 2236 Section 4 Page 6: "The IGMPv1 router expects Version 1 @@ -84,20 +87,23 @@ type igmpState struct { // when false. igmpV1Present uint32 - mu struct { - sync.RWMutex - - genericMulticastProtocol ip.GenericMulticastProtocolState + // igmpV1Job is scheduled when this interface receives an IGMPv1 style + // message, upon expiration the igmpV1Present flag is cleared. + // igmpV1Job may not be nil once igmpState is initialized. + igmpV1Job *tcpip.Job +} - // igmpV1Job is scheduled when this interface receives an IGMPv1 style - // message, upon expiration the igmpV1Present flag is cleared. - // igmpV1Job may not be nil once igmpState is initialized. - igmpV1Job *tcpip.Job - } +// Enabled implements ip.MulticastGroupProtocol. +func (igmp *igmpState) Enabled() bool { + // No need to perform IGMP on loopback interfaces since they don't have + // neighbouring nodes. 
+ return igmp.ep.protocol.options.IGMP.Enabled && !igmp.ep.nic.IsLoopback() && igmp.ep.Enabled() } // SendReport implements ip.MulticastGroupProtocol. -func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { igmpType := header.IGMPv2MembershipReport if igmp.v1Present() { igmpType = header.IGMPv1MembershipReport @@ -106,6 +112,8 @@ func (igmp *igmpState) SendReport(groupAddress tcpip.Address) *tcpip.Error { } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { // As per RFC 2236 Section 6, Page 8: "If the interface state says the // Querier is running IGMPv1, this action SHOULD be skipped. If the flag @@ -114,18 +122,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { if igmp.v1Present() { return nil } - return igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + _, err := igmp.writePacket(header.IPv4AllRoutersGroup, groupAddress, header.IGMPLeaveGroup) + return err } // init sets up an igmpState struct, and is required to be called before using // a new igmpState. -func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { - igmp.mu.Lock() - defer igmp.mu.Unlock() +// +// Must only be called once for the lifetime of igmp. 
+func (igmp *igmpState) init(ep *endpoint) { igmp.ep = ep - igmp.opts = opts - igmp.mu.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + igmp.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: igmp, @@ -133,11 +140,14 @@ func (igmp *igmpState) init(ep *endpoint, opts IGMPOptions) { AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault - igmp.mu.igmpV1Job = igmp.ep.protocol.stack.NewJob(&igmp.mu, func() { + igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { igmp.setV1Present(false) }) } +// handleIGMP handles an IGMP packet. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleIGMP(pkt *stack.PacketBuffer) { stats := igmp.ep.protocol.stack.Stats() received := stats.IGMP.PacketsReceived @@ -207,32 +217,34 @@ func (igmp *igmpState) setV1Present(v bool) { } } +// handleMembershipQuery handles a membership query. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) handleMembershipQuery(groupAddress tcpip.Address, maxRespTime time.Duration) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // As per RFC 2236 Section 6, Page 10: If the maximum response time is zero // then change the state to note that an IGMPv1 router is present and // schedule the query received Job. - if maxRespTime == 0 && igmp.opts.Enabled { - igmp.mu.igmpV1Job.Cancel() - igmp.mu.igmpV1Job.Schedule(v1RouterPresentTimeout) + if maxRespTime == 0 && igmp.Enabled() { + igmp.igmpV1Job.Cancel() + igmp.igmpV1Job.Schedule(v1RouterPresentTimeout) igmp.setV1Present(true) maxRespTime = v1MaxRespTime } - igmp.mu.genericMulticastProtocol.HandleQuery(groupAddress, maxRespTime) + igmp.genericMulticastProtocol.HandleQueryLocked(groupAddress, maxRespTime) } +// handleMembershipReport handles a membership report. +// +// Precondition: igmp.ep.mu must be locked. 
func (igmp *igmpState) handleMembershipReport(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.HandleReport(groupAddress) + igmp.genericMulticastProtocol.HandleReportLocked(groupAddress) } -// writePacket assembles and sends an IGMP packet with the provided fields, -// incrementing the provided stat counter on success. -func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) *tcpip.Error { +// writePacket assembles and sends an IGMP packet. +// +// Precondition: igmp.ep.mu must be read locked. +func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip.Address, igmpType header.IGMPType) (bool, *tcpip.Error) { igmpData := header.IGMP(buffer.NewView(header.IGMPReportMinimumSize)) igmpData.SetType(igmpType) igmpData.SetGroupAddress(groupAddress) @@ -243,9 +255,13 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip Data: buffer.View(igmpData).ToVectorisedView(), }) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. 
- localAddr := header.IPv4Any + addressEndpoint := igmp.ep.acquireOutgoingPrimaryAddressRLocked(destAddress, false /* allowExpired */) + if addressEndpoint == nil { + return false, nil + } + localAddr := addressEndpoint.AddressWithPrefix().Address + addressEndpoint.DecRef() + addressEndpoint = nil igmp.ep.addIPHeader(localAddr, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.IGMPProtocolNumber, TTL: header.IGMPTTL, @@ -254,22 +270,22 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip &header.IPv4SerializableRouterAlertOption{}, }) - sent := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent + sentStats := igmp.ep.protocol.stack.Stats().IGMP.PacketsSent if err := igmp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv4Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { - sent.Dropped.Increment() - return err + sentStats.Dropped.Increment() + return false, err } switch igmpType { case header.IGMPv1MembershipReport: - sent.V1MembershipReport.Increment() + sentStats.V1MembershipReport.Increment() case header.IGMPv2MembershipReport: - sent.V2MembershipReport.Increment() + sentStats.V2MembershipReport.Increment() case header.IGMPLeaveGroup: - sent.LeaveGroup.Increment() + sentStats.LeaveGroup.Increment() default: panic(fmt.Sprintf("unrecognized igmp type = %d", igmpType)) } - return nil + return true, nil } // joinGroup handles adding a new group to the membership map, setting up the @@ -278,28 +294,27 @@ func (igmp *igmpState) writePacket(destAddress tcpip.Address, groupAddress tcpip // // If the group already exists in the membership map, returns // tcpip.ErrDuplicateAddress. +// +// Precondition: igmp.ep.mu must be locked. 
func (igmp *igmpState) joinGroup(groupAddress tcpip.Address) { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.JoinGroup(groupAddress, !igmp.ep.Enabled() /* dontInitialize */) + igmp.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: igmp.ep.mu must be read locked. func (igmp *igmpState) isInGroup(groupAddress tcpip.Address) bool { - igmp.mu.Lock() - defer igmp.mu.Unlock() - return igmp.mu.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return igmp.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Leave Group message // if required. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { - igmp.mu.Lock() - defer igmp.mu.Unlock() - // LeaveGroup returns false only if the group was not joined. - if igmp.mu.genericMulticastProtocol.LeaveGroup(groupAddress) { + if igmp.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -308,16 +323,23 @@ func (igmp *igmpState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of IGMP, but remains // joined locally. +// +// Precondition: igmp.ep.mu must be locked. func (igmp *igmpState) softLeaveAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.MakeAllNonMember() + igmp.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the IGMP state for each group that has // been joined locally. +// +// Precondition: igmp.ep.mu must be locked. 
func (igmp *igmpState) initializeAll() { - igmp.mu.Lock() - defer igmp.mu.Unlock() - igmp.mu.genericMulticastProtocol.InitializeGroups() + igmp.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: igmp.ep.mu must be locked. +func (igmp *igmpState) sendQueuedReports() { + igmp.genericMulticastProtocol.SendQueuedReportsLocked() } diff --git a/pkg/tcpip/network/ipv4/igmp_test.go b/pkg/tcpip/network/ipv4/igmp_test.go index 5e139377b..1ee573ac8 100644 --- a/pkg/tcpip/network/ipv4/igmp_test.go +++ b/pkg/tcpip/network/ipv4/igmp_test.go @@ -16,6 +16,7 @@ package ipv4_test import ( "testing" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -29,6 +30,7 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + addr = tcpip.Address("\x0a\x00\x00\x01") multicastAddr = tcpip.Address("\xe0\x00\x00\x03") nicID = 1 ) @@ -41,6 +43,7 @@ func validateIgmpPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -71,7 +74,6 @@ func createStack(t *testing.T, igmpEnabled bool) (*channel.Endpoint, *stack.Stac if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - return e, s, clock } @@ -104,6 +106,9 @@ func createAndInjectIGMPPacket(e *channel.Endpoint, igmpType header.IGMPType, ma // reports for backwards compatibility. 
func TestIgmpV1Present(t *testing.T) { e, s, clock := createStack(t, true) + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { t.Fatalf("JoinGroup(ipv4, nic, %s) = %s", multicastAddr, err) @@ -154,3 +159,57 @@ func TestIgmpV1Present(t *testing.T) { } validateIgmpPacket(t, p, multicastAddr, header.IGMPv1MembershipReport, 0, multicastAddr) } + +func TestSendQueuedIGMPReports(t *testing.T) { + e, s, clock := createStack(t, true) + + // Joining a group without an assigned address should queue IGMP packets; none + // should be sent without an assigned address. + if err := s.JoinGroup(ipv4.ProtocolNumber, nicID, multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv4.ProtocolNumber, nicID, multicastAddr, err) + } + reportStat := s.Stats().IGMP.PacketsSent.V2MembershipReport + if got := reportStat.Value(); got != 0 { + t.Errorf("got reportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } + + // The initial set of IGMP reports that were queued should be sent once an + // address is assigned. 
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, addr, err) + } + if got := reportStat.Value(); got != 1 { + t.Errorf("got reportStat.Value() = %d, want = 1", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + clock.Advance(ipv4.UnsolicitedReportIntervalMax) + if got := reportStat.Value(); got != 2 { + t.Errorf("got reportStat.Value() = %d, want = 2", got) + } + if p, ok := e.Read(); !ok { + t.Error("expected to send an IGMP membership report") + } else { + validateIgmpPacket(t, p, multicastAddr, header.IGMPv2MembershipReport, 0, multicastAddr) + } + if t.Failed() { + t.FailNow() + } + + // Should have no more packets to send after the initial set of unsolicited + // reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("got unexpected packet = %#v", p) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 3076185cd..e9ff70d04 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -72,7 +72,6 @@ type endpoint struct { nic stack.NetworkInterface dispatcher stack.TransportDispatcher protocol *protocol - igmp igmpState // enabled is set to 1 when the enpoint is enabled and 0 when it is // disabled. 
@@ -84,6 +83,7 @@ type endpoint struct { sync.RWMutex addressableEndpointState stack.AddressableEndpointState + igmp igmpState } } @@ -94,8 +94,10 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, _ stack.LinkAddressCa dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.igmp.init(e, p.options.IGMP) + e.mu.igmp.init(e) + e.mu.Unlock() return e } @@ -127,7 +129,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of IGMP when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.igmp.initializeAll() + e.mu.igmp.initializeAll() // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts // multicast group. Note, the IANA calls the all-hosts multicast group the @@ -170,7 +172,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.isEnabled() { return } @@ -181,12 +183,16 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of IGMP so that routers know that // we are no longer interested in the group. - e.igmp.softLeaveAll() + e.mu.igmp.softLeaveAll() // The address may have already been removed. if err := e.mu.addressableEndpointState.RemovePermanentAddress(ipv4BroadcastAddr.Address); err != nil && err != tcpip.ErrBadLocalAddress { panic(fmt.Sprintf("unexpected error when removing address = %s: %s", ipv4BroadcastAddr.Address, err)) } + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // DefaultTTL is the default time-to-live value for this endpoint. 
@@ -718,7 +724,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { return } if p == header.IGMPProtocolNumber { - e.igmp.handleIGMP(pkt) + e.mu.Lock() + e.mu.igmp.handleIGMP(pkt) + e.mu.Unlock() return } if opts := h.Options(); len(opts) != 0 { @@ -776,7 +784,12 @@ func (e *endpoint) Close() { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, *tcpip.Error) { e.mu.Lock() defer e.mu.Unlock() - return e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + + ep, err := e.mu.addressableEndpointState.AddAndAcquirePermanentAddress(addr, peb, configType, deprecated) + if err == nil { + e.mu.igmp.sendQueuedReports() + } + return ep, err } // RemovePermanentAddress implements stack.AddressableEndpoint. @@ -811,6 +824,14 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { e.mu.RLock() defer e.mu.RUnlock() + return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) +} + +// acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress +// but with locking requirements. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address, allowExpired bool) stack.AddressEndpoint { return e.mu.addressableEndpointState.AcquireOutgoingPrimaryAddress(remoteAddr, allowExpired) } @@ -843,7 +864,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.igmp.joinGroup(addr) + e.mu.igmp.joinGroup(addr) return nil } @@ -858,14 +879,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. 
func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.igmp.leaveGroup(addr) + return e.mu.igmp.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.igmp.isInGroup(addr) + return e.mu.igmp.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9e2d2cfd6..ef62fe6fc 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2669,8 +2669,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2712,8 +2712,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2761,8 +2761,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, 
arp.ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) + if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 5e75c8740..afa45aefe 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -58,7 +58,10 @@ go_test( srcs = ["mld_test.go"], deps = [ ":ipv6", + "//pkg/tcpip", + "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 510276b8e..6ee162713 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -645,26 +645,34 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { } case header.ICMPv6MulticastListenerQuery, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone: - var handler func(header.MLD) switch icmpType { case header.ICMPv6MulticastListenerQuery: received.MulticastListenerQuery.Increment() - handler = e.mld.handleMulticastListenerQuery case header.ICMPv6MulticastListenerReport: received.MulticastListenerReport.Increment() - handler = e.mld.handleMulticastListenerReport case header.ICMPv6MulticastListenerDone: received.MulticastListenerDone.Increment() default: panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } + if pkt.Data.Size()-header.ICMPv6HeaderSize < header.MLDMinimumSize { received.Invalid.Increment() return } - if handler != nil { - handler(header.MLD(payload.ToView())) + switch icmpType { + case 
header.ICMPv6MulticastListenerQuery: + e.mu.Lock() + e.mu.mld.handleMulticastListenerQuery(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerReport: + e.mu.Lock() + e.mu.mld.handleMulticastListenerReport(header.MLD(payload.ToView())) + e.mu.Unlock() + case header.ICMPv6MulticastListenerDone: + default: + panic(fmt.Sprintf("unrecognized MLD message = %d", icmpType)) } default: diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 32adb5c83..34a6a8446 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -149,9 +149,8 @@ func (*testInterface) Promiscuous() bool { } func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - r := stack.Route{ - NetProto: protocol, - } + var r stack.Route + r.NetProto = protocol r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -296,11 +295,11 @@ func TestICMPCounts(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -454,11 +453,11 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: 
lladdr1, + DstAddr: lladdr0, }) ep.HandlePacket(pkt) } @@ -600,8 +599,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { - t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) + if len(args.remoteLinkAddr) != 0 && pi.Route.RemoteLinkAddress != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -853,11 +852,11 @@ func TestICMPChecksumValidationSimple(t *testing.T) { } ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), @@ -930,11 +929,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1048,11 +1047,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - 
SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(icmpSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1108,11 +1107,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { errorICMPBody := func(view buffer.View) { ip := header.IPv6(view) ip.Encode(&header.IPv6Fields{ - PayloadLength: simpleBodySize, - NextHeader: 10, - HopLimit: 20, - SrcAddr: lladdr0, - DstAddr: lladdr1, + PayloadLength: simpleBodySize, + TransportProtocol: 10, + HopLimit: 20, + SrcAddr: lladdr0, + DstAddr: lladdr1, }) simpleBody(view[header.IPv6MinimumSize:]) } @@ -1227,11 +1226,11 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(size + payloadSize), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(size + payloadSize), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(header.IPv6MinimumSize+size+payloadSize, []buffer.View{hdr.View(), payload}), @@ -1381,8 +1380,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got 
pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1445,11 +1444,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1463,8 +1462,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1487,11 +1486,11 @@ func TestPacketQueing(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: DefaultTTL, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + 
HopLimit: DefaultTTL, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1505,8 +1504,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) + if p.Route.RemoteLinkAddress != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1556,8 +1555,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1586,11 +1585,11 @@ func TestPacketQueing(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, - DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + PayloadLength: uint16(payloadLength), + 
TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -1828,11 +1827,11 @@ func TestCallsToNeighborCache(t *testing.T) { }) ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.source, - DstAddr: test.destination, + PayloadLength: uint16(len(icmp)), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.source, + DstAddr: test.destination, }) ep.HandlePacket(pkt) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 8bf84601f..f2018d073 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -19,6 +19,7 @@ import ( "encoding/binary" "fmt" "hash/fnv" + "math" "sort" "sync/atomic" "time" @@ -60,6 +61,108 @@ const ( buckets = 2048 ) +// policyTable is the default policy table defined in RFC 6724 section 2.1. +// +// A more human-readable version: +// +// Prefix Precedence Label +// ::1/128 50 0 +// ::/0 40 1 +// ::ffff:0:0/96 35 4 +// 2002::/16 30 2 +// 2001::/32 5 5 +// fc00::/7 3 13 +// ::/96 1 3 +// fec0::/10 1 11 +// 3ffe::/16 1 12 +// +// The table is sorted by prefix length so longest-prefix match can be easily +// achieved. +// +// We willingly left out ::/96, fec0::/10 and 3ffe::/16 since those prefix +// assignments are deprecated. +// +// As per RFC 4291 section 2.5.5.1 (for ::/96), +// +// The "IPv4-Compatible IPv6 address" is now deprecated because the +// current IPv6 transition mechanisms no longer use these addresses. +// New or updated implementations are not required to support this +// address type. 
+// +// As per RFC 3879 section 4 (for fec0::/10), +// +// This document formally deprecates the IPv6 site-local unicast prefix +// defined in [RFC3513], i.e., 1111111011 binary or FEC0::/10. +// +// As per RFC 3701 section 1 (for 3ffe::/16), +// +// As clearly stated in [TEST-NEW], the addresses for the 6bone are +// temporary and will be reclaimed in the future. It further states +// that all users of these addresses (within the 3FFE::/16 prefix) will +// be required to renumber at some time in the future. +// +// and section 2, +// +// Thus after the pTLA allocation cutoff date January 1, 2004, it is +// REQUIRED that no new 6bone 3FFE pTLAs be allocated. +// +// MUST NOT BE MODIFIED. +var policyTable = [...]struct { + subnet tcpip.Subnet + + label uint8 +}{ + // ::1/128 + { + subnet: header.IPv6Loopback.WithPrefix().Subnet(), + label: 0, + }, + // ::ffff:0:0/96 + { + subnet: header.IPv4MappedIPv6Subnet, + label: 4, + }, + // 2001::/32 (Teredo prefix as per RFC 4380 section 2.6). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 32, + }.Subnet(), + label: 5, + }, + // 2002::/16 (6to4 prefix as per RFC 3056 section 2). + { + subnet: tcpip.AddressWithPrefix{ + Address: "\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 16, + }.Subnet(), + label: 2, + }, + // fc00::/7 (Unique local addresses as per RFC 4193 section 3.1). 
+ { + subnet: tcpip.AddressWithPrefix{ + Address: "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + PrefixLen: 7, + }.Subnet(), + label: 13, + }, + // ::/0 + { + subnet: header.IPv6EmptySubnet, + label: 1, + }, +} + +func getLabel(addr tcpip.Address) uint8 { + for _, p := range policyTable { + if p.subnet.Contains(addr) { + return p.label + } + } + + panic(fmt.Sprintf("should have a label for address = %s", addr)) +} + var _ stack.GroupAddressableEndpoint = (*endpoint)(nil) var _ stack.AddressableEndpoint = (*endpoint)(nil) var _ stack.NetworkEndpoint = (*endpoint)(nil) @@ -85,9 +188,8 @@ type endpoint struct { addressableEndpointState stack.AddressableEndpointState ndp ndpState + mld mldState } - - mld mldState } // NICNameFromID is a function that returns a stable name for the specified NIC, @@ -122,6 +224,45 @@ type OpaqueInterfaceIdentifierOptions struct { SecretKey []byte } +// onAddressAssignedLocked handles an address being assigned. +// +// Precondition: e.mu must be exclusively locked. +func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, ... + // + // If we just completed DAD for a link-local address, then attempt to send any + // queued MLD reports. Note, we may have sent reports already for some of the + // groups before we had a valid link-local address to use as the source for + // the MLD messages, but that was only so that MLD snooping switches are aware + // of our membership to groups - routers would not have handled those reports. + // + // As per RFC 3590 section 4, + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. 
If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. + if header.IsV6LinkLocalAddress(addr) { + e.mu.mld.sendQueuedReports() + } +} + // InvalidateDefaultRouter implements stack.NDPEndpoint. func (e *endpoint) InvalidateDefaultRouter(rtr tcpip.Address) { e.mu.Lock() @@ -232,7 +373,7 @@ func (e *endpoint) Enable() *tcpip.Error { // endpoint may have left groups from the perspective of MLD when the // endpoint was disabled. Either way, we need to let routers know to // send us multicast traffic. - e.mld.initializeAll() + e.mu.mld.initializeAll() // Join the IPv6 All-Nodes Multicast group if the stack is configured to // use IPv6. This is required to ensure that this node properly receives @@ -334,7 +475,7 @@ func (e *endpoint) Disable() { } func (e *endpoint) disableLocked() { - if !e.setEnabled(false) { + if !e.Enabled() { return } @@ -349,7 +490,11 @@ func (e *endpoint) disableLocked() { // Leave groups from the perspective of MLD so that routers know that // we are no longer interested in the group. - e.mld.softLeaveAll() + e.mu.mld.softLeaveAll() + + if !e.setEnabled(false) { + panic("should have only done work to disable the endpoint if it was enabled") + } } // stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. 
@@ -389,19 +534,27 @@ func (e *endpoint) MTU() uint32 { // MaxHeaderLength returns the maximum length needed by ipv6 headers (and // underlying protocols). func (e *endpoint) MaxHeaderLength() uint16 { + // TODO(gvisor.dev/issues/5035): The maximum header length returned here does + // not open the possibility for the caller to know about size required for + // extension headers. return e.nic.MaxHeaderLength() + header.IPv6MinimumSize } -func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) +func (e *endpoint) addIPHeader(srcAddr, dstAddr tcpip.Address, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams, extensionHeaders header.IPv6ExtHdrSerializer) { + extHdrsLen := extensionHeaders.Length() + length := pkt.Size() + extensionHeaders.Length() + if length > math.MaxUint16 { + panic(fmt.Sprintf("IPv6 payload too large: %d, must be <= %d", length, math.MaxUint16)) + } + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(params.Protocol), - HopLimit: params.TTL, - TrafficClass: params.TOS, - SrcAddr: srcAddr, - DstAddr: dstAddr, + PayloadLength: uint16(length), + TransportProtocol: params.Protocol, + HopLimit: params.TTL, + TrafficClass: params.TOS, + SrcAddr: srcAddr, + DstAddr: dstAddr, + ExtensionHeaders: extensionHeaders, }) pkt.NetworkProtocolNumber = ProtocolNumber } @@ -456,7 +609,7 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU ui // WritePacket writes a packet to the given destination address and protocol. 
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pkt, params, nil /* extensionHeaders */) // iptables filtering. All packets that reach here are locally // generated. @@ -545,7 +698,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { - e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params) + e.addIPHeader(r.LocalAddress, r.RemoteAddress, pb, params, nil /* extensionHeaders */) networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) if err != nil { @@ -1177,13 +1330,6 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre return addressEndpoint, nil } - snmc := header.SolicitedNodeAddr(addr.Address) - if err := e.joinGroupLocked(snmc); err != nil { - // joinGroupLocked only returns an error if the group address is not a valid - // IPv6 multicast address. - panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) - } - addressEndpoint.SetKind(stack.PermanentTentative) if e.Enabled() { @@ -1192,6 +1338,13 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre } } + snmc := header.SolicitedNodeAddr(addr.Address) + if err := e.joinGroupLocked(snmc); err != nil { + // joinGroupLocked only returns an error if the group address is not a valid + // IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", snmc, err)) + } + return addressEndpoint, nil } @@ -1293,6 +1446,26 @@ func (e *endpoint) AcquireOutgoingPrimaryAddress(remoteAddr tcpip.Address, allow return e.acquireOutgoingPrimaryAddressRLocked(remoteAddr, allowExpired) } +// getLinkLocalAddressRLocked returns a link-local address from the primary list +// of addresses, if one is available. 
+// +// See stack.PrimaryEndpointBehavior for more details about the primary list. +// +// Precondition: e.mu must be read locked. +func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { + var linkLocalAddr tcpip.Address + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.IsAssigned(false /* allowExpired */) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + linkLocalAddr = addr + return false + } + } + return true + }) + return linkLocalAddr +} + // acquireOutgoingPrimaryAddressRLocked is like AcquireOutgoingPrimaryAddress // but with locking requirements. // @@ -1302,7 +1475,11 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // RFC 6724 section 5. type addrCandidate struct { addressEndpoint stack.AddressEndpoint + addr tcpip.Address scope header.IPv6AddressScope + + label uint8 + matchingPrefix uint8 } if len(remoteAddr) == 0 { @@ -1312,10 +1489,10 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address // Create a candidate set of available addresses we can potentially use as a // source address. var cs []addrCandidate - e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) { + e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { // If r is not valid for outgoing connections, it is not a valid endpoint. 
if !addressEndpoint.IsAssigned(allowExpired) { - return + return true } addr := addressEndpoint.AddressWithPrefix().Address @@ -1329,8 +1506,13 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address cs = append(cs, addrCandidate{ addressEndpoint: addressEndpoint, + addr: addr, scope: scope, + label: getLabel(addr), + matchingPrefix: remoteAddr.MatchingPrefix(addr), }) + + return true }) remoteScope, err := header.ScopeForIPv6Address(remoteAddr) @@ -1339,18 +1521,20 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address panic(fmt.Sprintf("header.ScopeForIPv6Address(%s): %s", remoteAddr, err)) } + remoteLabel := getLabel(remoteAddr) + // Sort the addresses as per RFC 6724 section 5 rules 1-3. // - // TODO(b/146021396): Implement rules 4-8 of RFC 6724 section 5. + // TODO(b/146021396): Implement rules 4, 5 of RFC 6724 section 5. sort.Slice(cs, func(i, j int) bool { sa := cs[i] sb := cs[j] // Prefer same address as per RFC 6724 section 5 rule 1. - if sa.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sa.addr == remoteAddr { return true } - if sb.addressEndpoint.AddressWithPrefix().Address == remoteAddr { + if sb.addr == remoteAddr { return false } @@ -1367,11 +1551,29 @@ func (e *endpoint) acquireOutgoingPrimaryAddressRLocked(remoteAddr tcpip.Address return sbDep } + // Prefer matching label as per RFC 6724 section 5 rule 6. + if sa, sb := sa.label == remoteLabel, sb.label == remoteLabel; sa != sb { + if sa { + return true + } + if sb { + return false + } + } + // Prefer temporary addresses as per RFC 6724 section 5 rule 7. if saTemp, sbTemp := sa.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp, sb.addressEndpoint.ConfigType() == stack.AddressConfigSlaacTemp; saTemp != sbTemp { return saTemp } + // Use longest matching prefix as per RFC 6724 section 5 rule 8. 
+ if sa.matchingPrefix > sb.matchingPrefix { + return true + } + if sb.matchingPrefix > sa.matchingPrefix { + return false + } + // sa and sb are equal, return the endpoint that is closest to the front of // the primary endpoint list. return i < j @@ -1417,7 +1619,7 @@ func (e *endpoint) joinGroupLocked(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadAddress } - e.mld.joinGroup(addr) + e.mu.mld.joinGroup(addr) return nil } @@ -1432,14 +1634,14 @@ func (e *endpoint) LeaveGroup(addr tcpip.Address) *tcpip.Error { // // Precondition: e.mu must be locked. func (e *endpoint) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { - return e.mld.leaveGroup(addr) + return e.mu.mld.leaveGroup(addr) } // IsInGroup implements stack.GroupAddressableEndpoint. func (e *endpoint) IsInGroup(addr tcpip.Address) bool { e.mu.RLock() defer e.mu.RUnlock() - return e.mld.isInGroup(addr) + return e.mu.mld.isInGroup(addr) } var _ stack.ForwardingNetworkProtocol = (*protocol)(nil) @@ -1504,17 +1706,11 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, linkAddrCache stack.L dispatcher: dispatcher, protocol: p, } + e.mu.Lock() e.mu.addressableEndpointState.Init(e) - e.mu.ndp = ndpState{ - ep: e, - configs: p.options.NDPConfigs, - dad: make(map[tcpip.Address]dadState), - defaultRouters: make(map[tcpip.Address]defaultRouterState), - onLinkPrefixes: make(map[tcpip.Subnet]onLinkPrefixState), - slaacPrefixes: make(map[tcpip.Subnet]slaacPrefixState), - } - e.mu.ndp.initializeTempAddrState() - e.mld.init(e, p.options.MLD) + e.mu.ndp.init(e) + e.mu.mld.init(e) + e.mu.Unlock() p.mu.Lock() defer p.mu.Unlock() @@ -1735,24 +1931,25 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea fragPkt.NetworkProtocolNumber = ProtocolNumber originalIPHeadersLength := len(originalIPHeaders) - fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize + + s := header.IPv6ExtHdrSerializer{&header.IPv6SerializableFragmentExtHdr{ + FragmentOffset: 
uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + M: more, + Identification: id, + }} + + fragmentIPHeadersLength := originalIPHeadersLength + s.Length() fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) - fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got %d, want %d", copied, originalIPHeadersLength)) } - fragmentIPHeaders.SetNextHeader(header.IPv6FragmentHeader) - fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) - fragmentHeader := header.IPv6Fragment(fragmentIPHeaders[originalIPHeadersLength:]) - fragmentHeader.Encode(&header.IPv6FragmentFields{ - M: more, - FragmentOffset: uint16(offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), - Identification: id, - NextHeader: uint8(transportProto), - }) + nextHeader, _ := s.Serialize(transportProto, fragmentIPHeaders[originalIPHeadersLength:]) + + fragmentIPHeaders.SetNextHeader(nextHeader) + fragmentIPHeaders.SetPayloadLength(uint16(copied + fragmentIPHeadersLength - header.IPv6MinimumSize)) return fragPkt, more } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 1c01f17ab..5f07d3af8 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -69,11 +69,11 @@ func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 
255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -127,11 +127,11 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 255, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(payloadLength), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 255, + SrcAddr: src, + DstAddr: dst, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -915,10 +915,12 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(payloadLength), - NextHeader: ipv6NextHdr, - HopLimit: 255, - SrcAddr: addr1, - DstAddr: dstAddr, + // We're lying about transport protocol here to be able to generate + // raw extension headers from the test definitions. + TransportProtocol: tcpip.TransportProtocolNumber(ipv6NextHdr), + HopLimit: 255, + SrcAddr: addr1, + DstAddr: dstAddr, }) e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1947,10 +1949,12 @@ func TestReceiveIPv6Fragments(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, + // We're lying about transport protocol here so that we can generate + // raw extension headers for the tests. 
+ TransportProtocol: tcpip.TransportProtocolNumber(f.nextHdr), + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, }) vv := hdr.View().ToVectorisedView() @@ -1995,7 +1999,7 @@ func TestInvalidIPv6Fragments(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2014,14 +2018,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 9, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 9, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0 >> 3, M: true, Identification: ident, @@ -2041,14 +2044,13 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, M: false, Identification: ident, @@ -2089,10 +2091,9 @@ func TestInvalidIPv6Fragments(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + 
header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) - - fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2154,7 +2155,7 @@ func TestFragmentReassemblyTimeout(t *testing.T) { type fragmentData struct { ipv6Fields header.IPv6Fields - ipv6FragmentFields header.IPv6FragmentFields + ipv6FragmentFields header.IPv6SerializableFragmentExtHdr payload []byte } @@ -2168,14 +2169,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2190,14 +2190,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2206,14 +2205,13 @@ func 
TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2228,14 +2226,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2250,14 +2247,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2266,14 +2262,13 @@ func 
TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2288,14 +2283,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { fragments: []fragmentData{ { ipv6Fields: header.IPv6Fields{ - PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 8, M: false, Identification: ident, @@ -2304,14 +2298,13 @@ func TestFragmentReassemblyTimeout(t *testing.T) { }, { ipv6Fields: header.IPv6Fields{ - PayloadLength: header.IPv6FragmentHeaderSize + 16, - NextHeader: header.IPv6FragmentHeader, - HopLimit: hoplimit, - SrcAddr: addr1, - DstAddr: addr2, + PayloadLength: header.IPv6FragmentHeaderSize + 16, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, }, - ipv6FragmentFields: header.IPv6FragmentFields{ - NextHeader: uint8(header.UDPProtocolNumber), + ipv6FragmentFields: header.IPv6SerializableFragmentExtHdr{ FragmentOffset: 0, M: true, Identification: ident, @@ -2350,10 +2343,11 @@ 
func TestFragmentReassemblyTimeout(t *testing.T) { hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) - ip.Encode(&f.ipv6Fields) + encodeArgs := f.ipv6Fields + encodeArgs.ExtensionHeaders = append(encodeArgs.ExtensionHeaders, &f.ipv6FragmentFields) + ip.Encode(&encodeArgs) fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) - fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() vv.AppendView(f.payload) @@ -2994,11 +2988,11 @@ func TestForwarding(t *testing.T) { icmp.SetChecksum(header.ICMPv6Checksum(icmp, remoteIPv6Addr1, remoteIPv6Addr2, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: test.TTL, - SrcAddr: remoteIPv6Addr1, - DstAddr: remoteIPv6Addr2, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: test.TTL, + SrcAddr: remoteIPv6Addr1, + DstAddr: remoteIPv6Addr2, }) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index 4c06b3f0c..e8d1e7a79 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -40,6 +40,9 @@ type MLDOptions struct { // When enabled, MLD may transmit MLD report and done messages when // joining and leaving multicast groups respectively, and handle incoming // MLD packets. + // + // This field is ignored and is always assumed to be false for interfaces + // without neighbouring nodes (e.g. loopback). Enabled bool } @@ -55,22 +58,35 @@ type mldState struct { genericMulticastProtocol ip.GenericMulticastProtocolState } +// Enabled implements ip.MulticastGroupProtocol. 
+func (mld *mldState) Enabled() bool { + // No need to perform MLD on loopback interfaces since they don't have + // neighbouring nodes. + return mld.ep.protocol.options.MLD.Enabled && !mld.ep.nic.IsLoopback() && mld.ep.Enabled() +} + // SendReport implements ip.MulticastGroupProtocol. -func (mld *mldState) SendReport(groupAddress tcpip.Address) *tcpip.Error { +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, *tcpip.Error) { return mld.writePacket(groupAddress, groupAddress, header.ICMPv6MulticastListenerReport) } // SendLeave implements ip.MulticastGroupProtocol. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) *tcpip.Error { - return mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + return err } // init sets up an mldState struct, and is required to be called before using // a new mldState. -func (mld *mldState) init(ep *endpoint, opts MLDOptions) { +// +// Must only be called once for the lifetime of mld. +func (mld *mldState) init(ep *endpoint) { mld.ep = ep - mld.genericMulticastProtocol.Init(ip.GenericMulticastProtocolOptions{ - Enabled: opts.Enabled, + mld.genericMulticastProtocol.Init(&ep.mu.RWMutex, ip.GenericMulticastProtocolOptions{ Rand: ep.protocol.stack.Rand(), Clock: ep.protocol.stack.Clock(), Protocol: mld, @@ -79,33 +95,45 @@ func (mld *mldState) init(ep *endpoint, opts MLDOptions) { }) } +// handleMulticastListenerQuery handles a query message. +// +// Precondition: mld.ep.mu must be locked. 
func (mld *mldState) handleMulticastListenerQuery(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleQuery(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) + mld.genericMulticastProtocol.HandleQueryLocked(mldHdr.MulticastAddress(), mldHdr.MaximumResponseDelay()) } +// handleMulticastListenerReport handles a report message. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) handleMulticastListenerReport(mldHdr header.MLD) { - mld.genericMulticastProtocol.HandleReport(mldHdr.MulticastAddress()) + mld.genericMulticastProtocol.HandleReportLocked(mldHdr.MulticastAddress()) } // joinGroup handles joining a new group and sending and scheduling the required // messages. // // If the group is already joined, returns tcpip.ErrDuplicateAddress. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) joinGroup(groupAddress tcpip.Address) { - mld.genericMulticastProtocol.JoinGroup(groupAddress, !mld.ep.Enabled() /* dontInitialize */) + mld.genericMulticastProtocol.JoinGroupLocked(groupAddress) } // isInGroup returns true if the specified group has been joined locally. +// +// Precondition: mld.ep.mu must be read locked. func (mld *mldState) isInGroup(groupAddress tcpip.Address) bool { - return mld.genericMulticastProtocol.IsLocallyJoined(groupAddress) + return mld.genericMulticastProtocol.IsLocallyJoinedRLocked(groupAddress) } // leaveGroup handles removing the group from the membership map, cancels any // delay timers associated with that group, and sends the Done message, if // required. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // LeaveGroup returns false only if the group was not joined. 
- if mld.genericMulticastProtocol.LeaveGroup(groupAddress) { + if mld.genericMulticastProtocol.LeaveGroupLocked(groupAddress) { return nil } @@ -114,17 +142,31 @@ func (mld *mldState) leaveGroup(groupAddress tcpip.Address) *tcpip.Error { // softLeaveAll leaves all groups from the perspective of MLD, but remains // joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) softLeaveAll() { - mld.genericMulticastProtocol.MakeAllNonMember() + mld.genericMulticastProtocol.MakeAllNonMemberLocked() } // initializeAll attemps to initialize the MLD state for each group that has // been joined locally. +// +// Precondition: mld.ep.mu must be locked. func (mld *mldState) initializeAll() { - mld.genericMulticastProtocol.InitializeGroups() + mld.genericMulticastProtocol.InitializeGroupsLocked() +} + +// sendQueuedReports attempts to send any reports that are queued for sending. +// +// Precondition: mld.ep.mu must be locked. +func (mld *mldState) sendQueuedReports() { + mld.genericMulticastProtocol.SendQueuedReportsLocked() } -func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) *tcpip.Error { +// writePacket assembles and sends an MLD packet. +// +// Precondition: mld.ep.mu must be read locked. +func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldType header.ICMPv6Type) (bool, *tcpip.Error) { sentStats := mld.ep.protocol.stack.Stats().ICMP.V6.PacketsSent var mldStat *tcpip.StatCounter switch mldType { @@ -139,26 +181,82 @@ func (mld *mldState) writePacket(destAddress, groupAddress tcpip.Address, mldTyp icmp := header.ICMPv6(buffer.NewView(header.ICMPv6HeaderSize + header.MLDMinimumSize)) icmp.SetType(mldType) header.MLD(icmp.MessageBody()).SetMulticastAddress(groupAddress) - // TODO(gvisor.dev/issue/4888): We should not use the unspecified address, - // rather we should select an appropriate local address. 
- localAddress := header.IPv6Any + // As per RFC 2710 section 3, + // + // All MLD messages described in this document are sent with a link-local + // IPv6 Source Address, an IPv6 Hop Limit of 1, and an IPv6 Router Alert + // option in a Hop-by-Hop Options header. + // + // However, this would cause problems with Duplicate Address Detection with + // the first address as MLD snooping switches may not send multicast traffic + // that DAD depends on to the node performing DAD without the MLD report, as + // documented in RFC 4816: + // + // Note that when a node joins a multicast address, it typically sends a + // Multicast Listener Discovery (MLD) report message [RFC2710] [RFC3810] + // for the multicast address. In the case of Duplicate Address + // Detection, the MLD report message is required in order to inform MLD- + // snooping switches, rather than routers, to forward multicast packets. + // In the above description, the delay for joining the multicast address + // thus means delaying transmission of the corresponding MLD report + // message. Since the MLD specifications do not request a random delay + // to avoid race conditions, just delaying Neighbor Solicitation would + // cause congestion by the MLD report messages. The congestion would + // then prevent the MLD-snooping switches from working correctly and, as + // a result, prevent Duplicate Address Detection from working. The + // requirement to include the delay for the MLD report in this case + // avoids this scenario. [RFC3590] also talks about some interaction + // issues between Duplicate Address Detection and MLD, and specifies + // which source address should be used for the MLD report in this case. 
+ // + // As per RFC 3590 section 4, we should still send out MLD reports with an + // unspecified source address if we do not have an assigned link-local + // address to use as the source address to ensure DAD works as expected on + // networks with MLD snooping switches: + // + // MLD Report and Done messages are sent with a link-local address as + // the IPv6 source address, if a valid address is available on the + // interface. If a valid link-local address is not available (e.g., one + // has not been configured), the message is sent with the unspecified + // address (::) as the IPv6 source address. + // + // Once a valid link-local address is available, a node SHOULD generate + // new MLD Report messages for all multicast addresses joined on the + // interface. + // + // Routers receiving an MLD Report or Done message with the unspecified + // address as the IPv6 source address MUST silently discard the packet + // without taking any action on the packets contents. + // + // Snooping switches MUST manage multicast forwarding state based on MLD + // Report and Done messages sent with the unspecified address as the + // IPv6 source address. 
+ localAddress := mld.ep.getLinkLocalAddressRLocked() + if len(localAddress) == 0 { + localAddress = header.IPv6Any + } + icmp.SetChecksum(header.ICMPv6Checksum(icmp, localAddress, destAddress, buffer.VectorisedView{})) + extensionHeaders := header.IPv6ExtHdrSerializer{ + header.IPv6SerializableHopByHopExtHdr{ + &header.IPv6RouterAlertOption{Value: header.IPv6RouterAlertMLD}, + }, + } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()), + ReserveHeaderBytes: int(mld.ep.MaxHeaderLength()) + extensionHeaders.Length(), Data: buffer.View(icmp).ToVectorisedView(), }) mld.ep.addIPHeader(localAddress, destAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.MLDHopLimit, - }) - // TODO(b/162198658): set the ROUTER_ALERT option when sending Host - // Membership Reports. + }, extensionHeaders) if err := mld.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(destAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sentStats.Dropped.Increment() - return err + return false, err } mldStat.Increment() - return nil + return localAddress != header.IPv6Any, nil } diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 5677bdd54..e2778b656 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -16,8 +16,12 @@ package ipv6_test import ( "testing" + "time" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -25,9 +29,34 @@ import ( ) const ( - addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + linkLocalAddr = "\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalAddr = 
"\x0a\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + globalMulticastAddr = "\xff\x05\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" ) +var ( + linkLocalAddrSNMC = header.SolicitedNodeAddr(linkLocalAddr) + globalAddrSNMC = header.SolicitedNodeAddr(globalAddr) +) + +func validateMLDPacket(t *testing.T, p buffer.View, localAddress, remoteAddress tcpip.Address, mldType header.ICMPv6Type, groupAddress tcpip.Address) { + t.Helper() + + checker.IPv6WithExtHdr(t, p, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(localAddress), + checker.DstAddr(remoteAddress), + // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. + checker.TTL(1), + checker.MLD(mldType, header.MLDMinimumSize, + checker.MLDMaxRespDelay(0), + checker.MLDMulticastAddress(groupAddress), + ), + ) +} + func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { const nicID = 1 @@ -46,45 +75,223 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) { // The stack will join an address's solicited node multicast address when // an address is added. An MLD report message should be sent for the // solicited-node group. - if err := s.AddAddress(nicID, ipv6.ProtocolNumber, addr1); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, addr1, err) + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a report message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(snmc), - // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. 
- checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerReport, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) } // The stack will leave an address's solicited node multicast address when // an address is removed. An MLD done message should be sent for the // solicited-node group. - if err := s.RemoveAddress(nicID, addr1); err != nil { - t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err) + if err := s.RemoveAddress(nicID, linkLocalAddr); err != nil { + t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a done message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC) } - { - p, ok := e.Read() - if !ok { - t.Fatal("expected a done message to be sent") - } - snmc := header.SolicitedNodeAddr(addr1) - checker.IPv6(t, header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), - checker.TTL(1), - checker.MLD(header.ICMPv6MulticastListenerDone, header.MLDMinimumSize, - checker.MLDMaxRespDelay(0), - checker.MLDMulticastAddress(snmc), - ), - ) +} + +func TestSendQueuedMLDReports(t *testing.T) { + const ( + nicID = 1 + maxReports = 2 + ) + + tests := []struct { + name string + dadTransmits uint8 + retransmitTimer time.Duration + }{ + { + name: "DAD Disabled", + dadTransmits: 0, + retransmitTimer: 0, + }, + { + name: "DAD Enabled", + dadTransmits: 1, + retransmitTimer: time.Second, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + dadResolutionTime := 
test.retransmitTimer * time.Duration(test.dadTransmits) + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{ + NDPConfigs: ipv6.NDPConfigurations{ + DupAddrDetectTransmits: test.dadTransmits, + RetransmitTimer: test.retransmitTimer, + }, + MLD: ipv6.MLDOptions{ + Enabled: true, + }, + })}, + Clock: clock, + }) + + // Allow space for an extra packet so we can observe packets that were + // unexpectedly sent. + e := channel.New(maxReports+int(test.dadTransmits)+1 /* extra */, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + + resolveDAD := func(addr, snmc tcpip.Address) { + clock.Advance(dadResolutionTime) + if p, ok := e.Read(); !ok { + t.Fatal("expected DAD packet") + } else { + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(header.IPv6Any), + checker.DstAddr(snmc), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(addr), + checker.NDPNSOptions(nil), + )) + } + } + + var reportCounter uint64 + reportStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + var doneCounter uint64 + doneStat := s.Stats().ICMP.V6.PacketsSent.MulticastListenerDone + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + + // Joining a group without an assigned address should send an MLD report + // with the unspecified address. 
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, globalMulticastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalMulticastAddr, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalMulticastAddr) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalMulticastAddr, header.ICMPv6MulticastListenerReport, globalMulticastAddr) + } + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a global address should not send reports for the already joined + // group since we should only send queued reports when a link-local + // address is assigned. + // + // Note, we will still expect to send a report for the global address's + // solicited node address from the unspecified address as per RFC 3590 + // section 4. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, globalAddr, stack.FirstPrimaryEndpoint, err) + } + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", globalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, globalAddrSNMC, header.ICMPv6MulticastListenerReport, globalAddrSNMC) + } + if dadResolutionTime != 0 { + // Reports should not be sent when the address resolves. 
+ resolveDAD(globalAddr, globalAddrSNMC) + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + } + // Leave the group since we don't care about the global address's + // solicited node multicast group membership. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, globalAddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, globalAddrSNMC, err) + } + if got := doneStat.Value(); got != doneCounter { + t.Errorf("got doneStat.Value() = %d, want = %d", got, doneCounter) + } + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + if t.Failed() { + t.FailNow() + } + + // Adding a link-local address should send a report for its solicited node + // address and globalMulticastAddr. + if err := s.AddAddressWithOptions(nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("AddAddressWithOptions(%d, %d, %s, %d): %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, stack.CanBePrimaryEndpoint, err) + } + if dadResolutionTime != 0 { + reportCounter++ + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Errorf("expected MLD report for %s", linkLocalAddrSNMC) + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + resolveDAD(linkLocalAddr, linkLocalAddrSNMC) + } + + // We expect two batches of reports to be sent (1 batch when the + // link-local address is assigned, and another after the maximum + // unsolicited report interval). + for i := 0; i < 2; i++ { + // We expect reports to be sent (one for globalMulticastAddr and another + // for linkLocalAddrSNMC). 
+ reportCounter += maxReports + if got := reportStat.Value(); got != reportCounter { + t.Errorf("got reportStat.Value() = %d, want = %d", got, reportCounter) + } + + addrs := map[tcpip.Address]bool{ + globalMulticastAddr: false, + linkLocalAddrSNMC: false, + } + for _ = range addrs { + p, ok := e.Read() + if !ok { + t.Fatalf("expected MLD report for %s and %s; addrs = %#v", globalMulticastAddr, linkLocalAddrSNMC, addrs) + } + + addr := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())).DestinationAddress() + if seen, ok := addrs[addr]; !ok { + t.Fatalf("got unexpected packet destined to %s", addr) + } else if seen { + t.Fatalf("got another packet destined to %s", addr) + } + + addrs[addr] = true + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, addr, header.ICMPv6MulticastListenerReport, addr) + + clock.Advance(ipv6.UnsolicitedReportIntervalMax) + } + } + + // Should not send any more reports. + clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Errorf("got unexpected packet = %#v", p) + } + }) } } diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 8cb7d4dab..d515eb622 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -20,6 +20,7 @@ import ( "math/rand" "time" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -459,6 +460,9 @@ func (c *NDPConfigurations) validate() { // ndpState is the per-interface NDP state. type ndpState struct { + // Do not allow overwriting this state. + _ sync.NoCopy + // The IPv6 endpoint this ndpState is for. 
ep *endpoint @@ -643,6 +647,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) } + ndp.ep.onAddressAssignedLocked(addr) return nil } @@ -686,12 +691,14 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, dadDone, err) } - // If DAD resolved for a stable SLAAC address, attempt generation of a - // temporary SLAAC address. - if dadDone && addressEndpoint.ConfigType() == stack.AddressConfigSlaac { - // Reset the generation attempts counter as we are starting the generation - // of a new address for the SLAAC prefix. - ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + if dadDone { + if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { + // Reset the generation attempts counter as we are starting the + // generation of a new address for the SLAAC prefix. 
+ ndp.regenerateTempSLAACAddr(addressEndpoint.AddressWithPrefix().Subnet(), true /* resetGenAttempts */) + } + + ndp.ep.onAddressAssignedLocked(addr) } }), } @@ -728,7 +735,7 @@ func (ndp *ndpState) sendDADPacket(addr tcpip.Address, addressEndpoint stack.Add ndp.ep.addIPHeader(header.IPv6Any, snmc, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }) + }, nil /* extensionHeaders */) if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(snmc), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() @@ -1850,7 +1857,7 @@ func (ndp *ndpState) startSolicitingRouters() { ndp.ep.addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, - }) + }, nil /* extensionHeaders */) if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.Dropped.Increment() @@ -1884,11 +1891,19 @@ func (ndp *ndpState) stopSolicitingRouters() { ndp.rtrSolicitJob = nil } -// initializeTempAddrState initializes state related to temporary SLAAC -// addresses. 
-func (ndp *ndpState) initializeTempAddrState() { - header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) +func (ndp *ndpState) init(ep *endpoint) { + if ndp.dad != nil { + panic("attempted to initialize NDP state twice") + } + ndp.ep = ep + ndp.configs = ep.protocol.options.NDPConfigs + ndp.dad = make(map[tcpip.Address]dadState) + ndp.defaultRouters = make(map[tcpip.Address]defaultRouterState) + ndp.onLinkPrefixes = make(map[tcpip.Subnet]onLinkPrefixState) + ndp.slaacPrefixes = make(map[tcpip.Subnet]slaacPrefixState) + + header.InitialTempIID(ndp.temporaryIIDHistory[:], ndp.ep.protocol.options.TempIIDSeed, ndp.ep.nic.ID()) if MaxDesyncFactor != 0 { ndp.temporaryAddressDesyncFactor = time.Duration(rand.Int63n(int64(MaxDesyncFactor))) } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 95c626bb8..7ddb19c00 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -213,11 +213,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -319,11 +319,11 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: 
header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -599,11 +599,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -681,11 +681,11 @@ func TestNeighorSolicitationResponse(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: test.nsSrc, - DstAddr: nicAddr, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, }) e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: 
hdr.View().ToVectorisedView(), @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -785,11 +785,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -898,11 +898,11 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, }) invalid := s.Stats().ICMP.V6.PacketsReceived.Invalid @@ -979,29 +979,25 @@ func TestNDPValidation(t *testing.T) { } handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions 
buffer.View + var extHdrs header.IPv6ExtHdrSerializer if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + extHdrs = append(extHdrs, &header.IPv6SerializableFragmentExtHdr{}) } + extHdrsLen := extHdrs.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + ReserveHeaderBytes: header.IPv6MinimumSize + extHdrsLen, Data: payload.ToVectorisedView(), }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + extHdrsLen)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, + PayloadLength: uint16(len(payload) + extHdrsLen), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: hopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + ExtensionHeaders: extHdrs, }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } ep.HandlePacket(pkt) } @@ -1351,11 +1347,11 @@ func TestRouterAdvertValidation(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, }) stats := s.Stats().ICMP.V6.PacketsReceived diff --git 
a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index 95fb67986..05d98a0a5 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -34,6 +35,9 @@ import ( const ( linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") + ipv4Addr = tcpip.Address("\x0a\x00\x00\x01") + ipv6Addr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv4MulticastAddr1 = tcpip.Address("\xe0\x00\x00\x03") ipv4MulticastAddr2 = tcpip.Address("\xe0\x00\x00\x04") ipv4MulticastAddr3 = tcpip.Address("\xe0\x00\x00\x05") @@ -48,6 +52,8 @@ const ( mldQuery = uint8(header.ICMPv6MulticastListenerQuery) mldReport = uint8(header.ICMPv6MulticastListenerReport) mldDone = uint8(header.ICMPv6MulticastListenerDone) + + maxUnsolicitedReports = 2 ) var ( @@ -61,6 +67,8 @@ var ( } return uint8(ipv4.UnsolicitedReportIntervalMax / decisecond) }() + + ipv6AddrSNMC = header.SolicitedNodeAddr(ipv6Addr) ) // validateMLDPacket checks that a passed PacketInfo is an IPv6 MLD packet @@ -69,7 +77,11 @@ func validateMLDPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip.A t.Helper() payload := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - checker.IPv6(t, payload, + checker.IPv6WithExtHdr(t, payload, + checker.IPv6ExtHdr( + checker.IPv6HopByHopExtensionHeader(checker.IPv6RouterAlert(header.IPv6RouterAlertMLD)), + ), + checker.SrcAddr(ipv6Addr), checker.DstAddr(remoteAddress), // Hop Limit for an MLD message must be 1 as per RFC 2710 section 3. checker.TTL(1), @@ -87,6 +99,7 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. 
payload := header.IPv4(stack.PayloadSince(p.Pkt.NetworkHeader())) checker.IPv4(t, payload, + checker.SrcAddr(ipv4Addr), checker.DstAddr(remoteAddress), // TTL for an IGMP message must be 1 as per RFC 2236 section 2. checker.TTL(1), @@ -99,23 +112,31 @@ func validateIGMPPacket(t *testing.T, p channel.PacketInfo, remoteAddress tcpip. ) } -func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { +func createStack(t *testing.T, v4, mgpEnabled bool) (*channel.Endpoint, *stack.Stack, *faketime.ManualClock) { t.Helper() - // Create an endpoint of queue size 2, since no more than 2 packets are ever - // queued in the tests in this file. - e := channel.New(2, 1280, linkAddr) + e := channel.New(maxUnsolicitedReports, header.IPv6MinimumMTU, linkAddr) + s, clock := createStackWithLinkEndpoint(t, v4, mgpEnabled, e) + return e, s, clock +} + +func createStackWithLinkEndpoint(t *testing.T, v4, mgpEnabled bool, e stack.LinkEndpoint) (*stack.Stack, *faketime.ManualClock) { + t.Helper() + + igmpEnabled := v4 && mgpEnabled + mldEnabled := !v4 && mgpEnabled + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocolWithOptions(ipv4.Options{ IGMP: ipv4.IGMPOptions{ - Enabled: mgpEnabled, + Enabled: igmpEnabled, }, }), ipv6.NewProtocolWithOptions(ipv6.Options{ MLD: ipv6.MLDOptions{ - Enabled: mgpEnabled, + Enabled: mldEnabled, }, }), }, @@ -124,8 +145,59 @@ func createStack(t *testing.T, mgpEnabled bool) (*channel.Endpoint, *stack.Stack if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, ipv4Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, ipv4Addr, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, ipv6Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, ipv6Addr, err) + } - return e, 
s, clock + return s, clock +} + +// checkInitialIPv6Groups checks the initial IPv6 groups that a NIC will join +// when it is created with an IPv6 address. +// +// To not interfere with tests, checkInitialIPv6Groups will leave the added +// address's solicited node multicast group so that the tests can all assume +// the NIC has not joined any IPv6 groups. +func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, clock *faketime.ManualClock) (reportCounter uint64, leaveCounter uint64) { + t.Helper() + + stats := s.Stats().ICMP.V6.PacketsSent + + reportCounter++ + if got := stats.MulticastListenerReport.Value(); got != reportCounter { + t.Errorf("got stats.MulticastListenerReport.Value() = %d, want = %d", got, reportCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, ipv6AddrSNMC, mldReport, 0, ipv6AddrSNMC) + } + + // Leave the group to not affect the tests. This is fine since we are not + // testing DAD or the solicited node address specifically. + if err := s.LeaveGroup(ipv6.ProtocolNumber, nicID, ipv6AddrSNMC); err != nil { + t.Fatalf("LeaveGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, ipv6AddrSNMC, err) + } + leaveCounter++ + if got := stats.MulticastListenerDone.Value(); got != leaveCounter { + t.Errorf("got stats.MulticastListenerDone.Value() = %d, want = %d", got, leaveCounter) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + } + + // Should not send any more packets. 
+ clock.Advance(time.Hour) + if p, ok := e.Read(); ok { + t.Fatalf("sent unexpected packet = %#v", p) + } + + return reportCounter, leaveCounter } // createAndInjectIGMPPacket creates and injects an IGMP packet with the @@ -170,11 +242,11 @@ func createAndInjectMLDPacket(e *channel.Endpoint, mldType uint8, maxRespDelay b ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(icmpSize), - HopLimit: header.MLDHopLimit, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - SrcAddr: header.IPv4Any, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(icmpSize), + HopLimit: header.MLDHopLimit, + TransportProtocol: header.ICMPv6ProtocolNumber, + SrcAddr: header.IPv4Any, + DstAddr: header.IPv6AllNodesMulticastAddress, }) icmp := header.ICMPv6(buf[header.IPv6MinimumSize:]) @@ -232,13 +304,13 @@ func TestMGPDisabled(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, false) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, false /* mgpEnabled */) // This NIC may join multicast groups when it is enabled but since MGP is // disabled, no reports should be sent. 
sentReportStat := test.sentReportStat(s) if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -251,7 +323,7 @@ func TestMGPDisabled(t *testing.T) { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } if got := sentReportStat.Value(); got != 0 { - t.Fatalf("got sentReportState.Value() = %d, want = 0", got) + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) } clock.Advance(time.Hour) if p, ok := e.Read(); ok { @@ -355,7 +427,7 @@ func TestMGPReceiveCounters(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, _ := createStack(t, true) + e, s, _ := createStack(t, len(test.groupAddress) == header.IPv4AddressSize /* v4 */, true /* mgpEnabled */) test.rxMGPkt(e, test.headerType, test.maxRespTime, test.groupAddress) if got := test.statCounter(s).Value(); got != 1 { @@ -376,6 +448,7 @@ func TestMGPJoinGroup(t *testing.T) { sentReportStat func(*stack.Stack) *tcpip.StatCounter receivedQueryStat func(*stack.Stack) *tcpip.StatCounter validateReport func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -410,21 +483,28 @@ func TestMGPJoinGroup(t *testing.T) { validateMLDPacket(t, p, ipv6MulticastAddr1, mldReport, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } // Test joining a specific address explicitly and verify a Report is sent // immediately. 
if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } + reportCounter++ sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportState.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -442,8 +522,9 @@ func TestMGPJoinGroup(t *testing.T) { t.Fatalf("sent unexpected packet, expected report only after advancing the clock = %#v", p.Pkt) } clock.Advance(test.maxUnsolicitedResponseDelay) - if got := sentReportStat.Value(); got != 2 { - t.Errorf("got sentReportState.Value() = %d, want = 2", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -464,13 +545,14 @@ func TestMGPJoinGroup(t *testing.T) { // group the stack sends a leave/done message. 
func TestMGPLeaveGroup(t *testing.T) { tests := []struct { - name string - protoNum tcpip.NetworkProtocolNumber - multicastAddr tcpip.Address - sentReportStat func(*stack.Stack) *tcpip.StatCounter - sentLeaveStat func(*stack.Stack) *tcpip.StatCounter - validateReport func(*testing.T, channel.PacketInfo) - validateLeave func(*testing.T, channel.PacketInfo) + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + sentLeaveStat func(*stack.Stack) *tcpip.StatCounter + validateReport func(*testing.T, channel.PacketInfo) + validateLeave func(*testing.T, channel.PacketInfo) + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -513,18 +595,26 @@ func TestMGPLeaveGroup(t *testing.T) { validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } - if got := test.sentReportStat(s).Value(); got != 1 { - t.Errorf("got sentReportStat(_).Value() = %d, want = 1", got) + reportCounter++ + if got := test.sentReportStat(s).Value(); got != reportCounter { + t.Errorf("got sentReportStat(_).Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -539,8 +629,9 @@ func TestMGPLeaveGroup(t *testing.T) { if err := s.LeaveGroup(test.protoNum, 
nicID, test.multicastAddr); err != nil { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } - if got := test.sentLeaveStat(s).Value(); got != 1 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 1", got) + leaveCounter++ + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a leave message to be sent") @@ -570,6 +661,7 @@ func TestMGPQueryMessages(t *testing.T) { rxQuery func(*channel.Endpoint, uint8, tcpip.Address) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -614,6 +706,7 @@ func TestMGPQueryMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } @@ -647,16 +740,22 @@ func TestMGPQueryMessages(t *testing.T) { for _, subTest := range subTests { t.Run(subTest.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, _ = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - for i := uint64(1); i <= 2; i++ { + for i := 0; i < maxUnsolicitedReports; i++ { sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != i { - t.Errorf("(i=%d) got sentReportState.Value() = %d, want = %d", i, got, i) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("(i=%d) got 
sentReportStat.Value() = %d, want = %d", i, got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatalf("expected %d-th report message to be sent", i) @@ -686,8 +785,9 @@ func TestMGPQueryMessages(t *testing.T) { if subTest.expectReport { clock.Advance(test.maxRespTimeToDuration(maxRespTime)) - if got := sentReportStat.Value(); got != 3 { - t.Errorf("got sentReportState.Value() = %d, want = 3", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -719,6 +819,7 @@ func TestMGPReportMessages(t *testing.T) { rxReport func(*channel.Endpoint) validateReport func(*testing.T, channel.PacketInfo) maxRespTimeToDuration func(uint8) time.Duration + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -761,19 +862,27 @@ func TestMGPReportMessages(t *testing.T) { maxRespTimeToDuration: func(d uint8) time.Duration { return time.Duration(d) * time.Millisecond }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) + + var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) } sentReportStat := test.sentReportStat(s) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + reportCounter++ + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", 
got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -788,8 +897,8 @@ func TestMGPReportMessages(t *testing.T) { // reports. test.rxReport(e) clock.Advance(time.Hour) - if got := sentReportStat.Value(); got != 1 { - t.Errorf("got sentReportStat.Value() = %d, want = 1", got) + if got := sentReportStat.Value(); got != reportCounter { + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); ok { t.Errorf("sent unexpected packet = %#v", p) @@ -804,8 +913,8 @@ func TestMGPReportMessages(t *testing.T) { t.Fatalf("LeaveGroup(%d, nic, %s): %s", test.protoNum, test.multicastAddr, err) } clock.Advance(time.Hour) - if got := test.sentLeaveStat(s).Value(); got != 0 { - t.Fatalf("got sentLeaveStat(_).Value() = %d, want = 0", got) + if got := test.sentLeaveStat(s).Value(); got != leaveCounter { + t.Fatalf("got sentLeaveStat(_).Value() = %d, want = %d", got, leaveCounter) } // Should not send any more packets. @@ -829,6 +938,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateReport func(*testing.T, channel.PacketInfo, tcpip.Address) validateLeave func(*testing.T, channel.PacketInfo, tcpip.Address) getAndCheckGroupAddress func(*testing.T, map[tcpip.Address]bool, channel.PacketInfo) tcpip.Address + checkInitialGroups func(*testing.T, *channel.Endpoint, *stack.Stack, *faketime.ManualClock) (uint64, uint64) }{ { name: "IGMP", @@ -897,10 +1007,31 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Helper() ipv6 := header.IPv6(stack.PayloadSince(p.Pkt.NetworkHeader())) - if got := tcpip.TransportProtocolNumber(ipv6.NextHeader()); got != header.ICMPv6ProtocolNumber { + + ipv6HeaderIter := header.MakeIPv6PayloadIterator( + header.IPv6ExtensionHeaderIdentifier(ipv6.NextHeader()), + buffer.View(ipv6.Payload()).ToVectorisedView(), + ) + + var transport header.IPv6RawPayloadHeader + for { + h, done, err := ipv6HeaderIter.Next() + if err != nil { + t.Fatalf("ipv6HeaderIter.Next(): %s", err) + } + 
if done { + t.Fatalf("ipv6HeaderIter.Next() = (%T, %t, _), want = (_, false, _)", h, done) + } + if t, ok := h.(header.IPv6RawPayloadHeader); ok { + transport = t + break + } + } + + if got := tcpip.TransportProtocolNumber(transport.Identifier); got != header.ICMPv6ProtocolNumber { t.Fatalf("got ipv6.NextHeader() = %d, want = %d", got, header.ICMPv6ProtocolNumber) } - icmpv6 := header.ICMPv6(ipv6.Payload()) + icmpv6 := header.ICMPv6(transport.Buf.ToView()) if got := icmpv6.Type(); got != header.ICMPv6MulticastListenerReport && got != header.ICMPv6MulticastListenerDone { t.Fatalf("got icmpv6.Type() = %d, want = %d or %d", got, header.ICMPv6MulticastListenerReport, header.ICMPv6MulticastListenerDone) } @@ -914,17 +1045,22 @@ func TestMGPWithNICLifecycle(t *testing.T) { } seen[addr] = true return addr - }, + checkInitialGroups: checkInitialIPv6Groups, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e, s, clock := createStack(t, true) + e, s, clock := createStack(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */) - sentReportStat := test.sentReportStat(s) var reportCounter uint64 + var leaveCounter uint64 + if test.checkInitialGroups != nil { + reportCounter, leaveCounter = test.checkInitialGroups(t, e, s, clock) + } + + sentReportStat := test.sentReportStat(s) for _, a := range test.multicastAddrs { if err := s.JoinGroup(test.protoNum, nicID, a); err != nil { t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, a, err) @@ -949,7 +1085,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { t.Fatalf("DisableNIC(%d): %s", nicID, err) } sentLeaveStat := test.sentLeaveStat(s) - leaveCounter := uint64(len(test.multicastAddrs)) + leaveCounter += uint64(len(test.multicastAddrs)) if got := sentLeaveStat.Value(); got != leaveCounter { t.Errorf("got sentLeaveStat.Value() = %d, want = %d", got, leaveCounter) } @@ -1051,7 +1187,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { clock.Advance(test.maxUnsolicitedResponseDelay) 
reportCounter++ if got := sentReportStat.Value(); got != reportCounter { - t.Errorf("got sentReportState.Value() = %d, want = %d", got, reportCounter) + t.Errorf("got sentReportStat.Value() = %d, want = %d", got, reportCounter) } if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") @@ -1067,3 +1203,59 @@ func TestMGPWithNICLifecycle(t *testing.T) { }) } } + +// TestMGPDisabledOnLoopback tests that the multicast group protocol is not +// performed on loopback interfaces since they have no neighbours. +func TestMGPDisabledOnLoopback(t *testing.T) { + tests := []struct { + name string + protoNum tcpip.NetworkProtocolNumber + multicastAddr tcpip.Address + sentReportStat func(*stack.Stack) *tcpip.StatCounter + }{ + { + name: "IGMP", + protoNum: ipv4.ProtocolNumber, + multicastAddr: ipv4MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().IGMP.PacketsSent.V2MembershipReport + }, + }, + { + name: "MLD", + protoNum: ipv6.ProtocolNumber, + multicastAddr: ipv6MulticastAddr1, + sentReportStat: func(s *stack.Stack) *tcpip.StatCounter { + return s.Stats().ICMP.V6.PacketsSent.MulticastListenerReport + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, clock := createStackWithLinkEndpoint(t, test.protoNum == ipv4.ProtocolNumber /* v4 */, true /* mgpEnabled */, loopback.New()) + + sentReportStat := test.sentReportStat(s) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + + // Test joining a specific group explicitly and verify that no reports are + // sent. 
+ if err := s.JoinGroup(test.protoNum, nicID, test.multicastAddr); err != nil { + t.Fatalf("JoinGroup(%d, %d, %s): %s", test.protoNum, nicID, test.multicastAddr, err) + } + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + clock.Advance(time.Hour) + if got := sentReportStat.Value(); got != 0 { + t.Fatalf("got sentReportStat.Value() = %d, want = 0", got) + } + }) + } +} diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go index c53698a6a..f3ad40fdf 100644 --- a/pkg/tcpip/socketops.go +++ b/pkg/tcpip/socketops.go @@ -16,6 +16,8 @@ package tcpip import ( "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" ) // SocketOptionsHandler holds methods that help define endpoint specific @@ -37,6 +39,15 @@ type SocketOptionsHandler interface { // OnCorkOptionSet is invoked when TCP_CORK is set for an endpoint. OnCorkOptionSet(v bool) + + // LastError is invoked when SO_ERROR is read for an endpoint. + LastError() *Error + + // UpdateLastError updates the endpoint specific last error field. + UpdateLastError(err *Error) + + // HasNIC is invoked to check if the NIC is valid for SO_BINDTODEVICE. + HasNIC(v int32) bool } // DefaultSocketOptionsHandler is an embeddable type that implements no-op @@ -60,6 +71,19 @@ func (*DefaultSocketOptionsHandler) OnDelayOptionSet(bool) {} // OnCorkOptionSet implements SocketOptionsHandler.OnCorkOptionSet. func (*DefaultSocketOptionsHandler) OnCorkOptionSet(bool) {} +// LastError implements SocketOptionsHandler.LastError. +func (*DefaultSocketOptionsHandler) LastError() *Error { + return nil +} + +// UpdateLastError implements SocketOptionsHandler.UpdateLastError. +func (*DefaultSocketOptionsHandler) UpdateLastError(*Error) {} + +// HasNIC implements SocketOptionsHandler.HasNIC. +func (*DefaultSocketOptionsHandler) HasNIC(int32) bool { + return false +} + // SocketOptions contains all the variables which store values for SOL_SOCKET, // SOL_IP, SOL_IPV6 and SOL_TCP level options. 
// @@ -69,24 +93,24 @@ type SocketOptions struct { // These fields are accessed and modified using atomic operations. - // broadcastEnabled determines whether datagram sockets are allowed to send - // packets to a broadcast address. + // broadcastEnabled determines whether datagram sockets are allowed to + // send packets to a broadcast address. broadcastEnabled uint32 - // passCredEnabled determines whether SCM_CREDENTIALS socket control messages - // are enabled. + // passCredEnabled determines whether SCM_CREDENTIALS socket control + // messages are enabled. passCredEnabled uint32 // noChecksumEnabled determines whether UDP checksum is disabled while // transmitting for this socket. noChecksumEnabled uint32 - // reuseAddressEnabled determines whether Bind() should allow reuse of local - // address. + // reuseAddressEnabled determines whether Bind() should allow reuse of + // local address. reuseAddressEnabled uint32 - // reusePortEnabled determines whether to permit multiple sockets to be bound - // to an identical socket address. + // reusePortEnabled determines whether to permit multiple sockets to be + // bound to an identical socket address. reusePortEnabled uint32 // keepAliveEnabled determines whether TCP keepalive is enabled for this @@ -94,7 +118,7 @@ type SocketOptions struct { keepAliveEnabled uint32 // multicastLoopEnabled determines whether multicast packets sent over a - // non-loopback interface will be looped back. Analogous to inet->mc_loop. + // non-loopback interface will be looped back. multicastLoopEnabled uint32 // receiveTOSEnabled is used to specify if the TOS ancillary message is @@ -130,6 +154,28 @@ type SocketOptions struct { // corkOptionEnabled is used to specify if data should be held until segments // are full by the TCP transport protocol. corkOptionEnabled uint32 + + // receiveOriginalDstAddress is used to specify if the original destination of + // the incoming packet should be returned as an ancillary message. 
+ receiveOriginalDstAddress uint32 + + // recvErrEnabled determines whether extended reliable error message passing + // is enabled. + recvErrEnabled uint32 + + // errQueue is the per-socket error queue. It is protected by errQueueMu. + errQueueMu sync.Mutex `state:"nosave"` + errQueue sockErrorList + + // bindToDevice determines the device to which the socket is bound. + bindToDevice int32 + + // mu protects the access to the below fields. + mu sync.Mutex `state:"nosave"` + + // linger determines the amount of time the socket should linger before + // close. We currently implement this option for TCP socket only. + linger LingerOption } // InitHandler initializes the handler. This must be called before using the @@ -146,6 +192,11 @@ func storeAtomicBool(addr *uint32, v bool) { atomic.StoreUint32(addr, val) } +// SetLastError sets the last error for a socket. +func (so *SocketOptions) SetLastError(err *Error) { + so.handler.UpdateLastError(err) +} + // GetBroadcast gets value for SO_BROADCAST option. func (so *SocketOptions) GetBroadcast() bool { return atomic.LoadUint32(&so.broadcastEnabled) != 0 @@ -302,3 +353,168 @@ func (so *SocketOptions) SetCorkOption(v bool) { storeAtomicBool(&so.corkOptionEnabled, v) so.handler.OnCorkOptionSet(v) } + +// GetReceiveOriginalDstAddress gets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) GetReceiveOriginalDstAddress() bool { + return atomic.LoadUint32(&so.receiveOriginalDstAddress) != 0 +} + +// SetReceiveOriginalDstAddress sets value for IP(V6)_RECVORIGDSTADDR option. +func (so *SocketOptions) SetReceiveOriginalDstAddress(v bool) { + storeAtomicBool(&so.receiveOriginalDstAddress, v) +} + +// GetRecvError gets value for IP*_RECVERR option. +func (so *SocketOptions) GetRecvError() bool { + return atomic.LoadUint32(&so.recvErrEnabled) != 0 +} + +// SetRecvError sets value for IP*_RECVERR option. 
+func (so *SocketOptions) SetRecvError(v bool) { + storeAtomicBool(&so.recvErrEnabled, v) + if !v { + so.pruneErrQueue() + } +} + +// GetLastError gets value for SO_ERROR option. +func (so *SocketOptions) GetLastError() *Error { + return so.handler.LastError() +} + +// GetOutOfBandInline gets value for SO_OOBINLINE option. +func (*SocketOptions) GetOutOfBandInline() bool { + return true +} + +// SetOutOfBandInline sets value for SO_OOBINLINE option. We currently do not +// support disabling this option. +func (*SocketOptions) SetOutOfBandInline(bool) {} + +// GetLinger gets value for SO_LINGER option. +func (so *SocketOptions) GetLinger() LingerOption { + so.mu.Lock() + linger := so.linger + so.mu.Unlock() + return linger +} + +// SetLinger sets value for SO_LINGER option. +func (so *SocketOptions) SetLinger(linger LingerOption) { + so.mu.Lock() + so.linger = linger + so.mu.Unlock() +} + +// SockErrOrigin represents the constants for error origin. +type SockErrOrigin uint8 + +const ( + // SockExtErrorOriginNone represents an unknown error origin. + SockExtErrorOriginNone SockErrOrigin = iota + + // SockExtErrorOriginLocal indicates a local error. + SockExtErrorOriginLocal + + // SockExtErrorOriginICMP indicates an IPv4 ICMP error. + SockExtErrorOriginICMP + + // SockExtErrorOriginICMP6 indicates an IPv6 ICMP error. + SockExtErrorOriginICMP6 +) + +// IsICMPErr indicates if the error originated from an ICMP error. +func (origin SockErrOrigin) IsICMPErr() bool { + return origin == SockExtErrorOriginICMP || origin == SockExtErrorOriginICMP6 +} + +// SockError represents a queue entry in the per-socket error queue. +// +// +stateify savable +type SockError struct { + sockErrorEntry + + // Err is the error caused by the errant packet. + Err *Error + // ErrOrigin indicates the error origin. + ErrOrigin SockErrOrigin + // ErrType is the type in the ICMP header. + ErrType uint8 + // ErrCode is the code in the ICMP header. 
+ ErrCode uint8 + // ErrInfo is additional info about the error. + ErrInfo uint32 + + // Payload is the errant packet's payload. + Payload []byte + // Dst is the original destination address of the errant packet. + Dst FullAddress + // Offender is the original sender address of the errant packet. + Offender FullAddress + // NetProto is the network protocol being used to transmit the packet. + NetProto NetworkProtocolNumber +} + +// pruneErrQueue resets the queue. +func (so *SocketOptions) pruneErrQueue() { + so.errQueueMu.Lock() + so.errQueue.Reset() + so.errQueueMu.Unlock() +} + +// DequeueErr dequeues a socket extended error from the error queue and returns +// it. Returns nil if queue is empty. +func (so *SocketOptions) DequeueErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + + err := so.errQueue.Front() + if err != nil { + so.errQueue.Remove(err) + } + return err +} + +// PeekErr returns the error in the front of the error queue. Returns nil if +// the error queue is empty. +func (so *SocketOptions) PeekErr() *SockError { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + return so.errQueue.Front() +} + +// QueueErr inserts the error at the back of the error queue. +// +// Preconditions: so.GetRecvError() == true. +func (so *SocketOptions) QueueErr(err *SockError) { + so.errQueueMu.Lock() + defer so.errQueueMu.Unlock() + so.errQueue.PushBack(err) +} + +// QueueLocalErr queues a local error onto the local queue. +func (so *SocketOptions) QueueLocalErr(err *Error, net NetworkProtocolNumber, info uint32, dst FullAddress, payload []byte) { + so.QueueErr(&SockError{ + Err: err, + ErrOrigin: SockExtErrorOriginLocal, + ErrInfo: info, + Payload: payload, + Dst: dst, + NetProto: net, + }) +} + +// GetBindToDevice gets value for SO_BINDTODEVICE option. +func (so *SocketOptions) GetBindToDevice() int32 { + return atomic.LoadInt32(&so.bindToDevice) +} + +// SetBindToDevice sets value for SO_BINDTODEVICE option. 
+func (so *SocketOptions) SetBindToDevice(bindToDevice int32) *Error { + if !so.handler.HasNIC(bindToDevice) { + return ErrUnknownDevice + } + + atomic.StoreInt32(&so.bindToDevice, bindToDevice) + return nil +} diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9cc6074da..bb30556cf 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -148,7 +148,6 @@ go_test( ], library = ":stack", deps = [ - "//pkg/sleep", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 6e4f5fa46..cd423bf71 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -82,12 +82,16 @@ func (a *AddressableEndpointState) ForEachEndpoint(f func(AddressEndpoint) bool) } // ForEachPrimaryEndpoint calls f for each primary address. -func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint)) { +// +// Once f returns false, f will no longer be called. +func (a *AddressableEndpointState) ForEachPrimaryEndpoint(f func(AddressEndpoint) bool) { a.mu.RLock() defer a.mu.RUnlock() for _, ep := range a.mu.primary { - f(ep) + if !f(ep) { + return + } } } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 5ec9b3411..93e8e1c51 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -560,6 +560,38 @@ func TestForwardingWithNoResolver(t *testing.T) { } } +func TestForwardingResolutionFailsForQueuedPackets(t *testing.T) { + proto := &fwdTestNetworkProtocol{ + addrResolveDelay: 50 * time.Millisecond, + onLinkAddressResolved: func(*linkAddrCache, *neighborCache, tcpip.Address, tcpip.LinkAddress) { + // Don't resolve the link address. 
+ }, + } + + ep1, ep2 := fwdTestNetFactory(t, proto, true /* useNeighborCache */) + + const numPackets int = 5 + // These packets will all be enqueued in the packet queue to wait for link + // address resolution. + for i := 0; i < numPackets; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } + + // All packets should fail resolution. + // TODO(gvisor.dev/issue/5141): Use a fake clock. + for i := 0; i < numPackets; i++ { + select { + case got := <-ep2.C: + t.Fatalf("got %#v; packets should have failed resolution and not been forwarded", got) + case <-time.After(100 * time.Millisecond): + } + } +} + func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { tests := []struct { name string diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index c9b13cd0e..792f4f170 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -18,7 +18,6 @@ import ( "fmt" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -58,9 +57,6 @@ const ( incomplete entryState = iota // ready means that the address has been resolved and can be used. ready - // failed means that address resolution timed out and the address - // could not be resolved. - failed ) // String implements Stringer. @@ -70,8 +66,6 @@ func (s entryState) String() string { return "incomplete" case ready: return "ready" - case failed: - return "failed" default: return fmt.Sprintf("unknown(%d)", s) } @@ -80,40 +74,48 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + // linkAddrEntryEntry access is synchronized by the linkAddrCache lock. linkAddrEntryEntry + // TODO(gvisor.dev/issue/5150): move these fields under mu. + // mu protects the fields below. 
+ mu sync.RWMutex + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState - // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of incomplete these waiters are notified. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil iff - // s is incomplete and resolution is not yet in progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) } -// changeState sets the entry's state to ns, notifying any waiters. +func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + for _, callback := range e.onResolve { + callback(linkAddr, len(linkAddr) != 0) + } + e.onResolve = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// changeStateLocked sets the entry's state to ns. // // The entry's expiration is bumped up to the greater of itself and the passed // expiration; the zero value indicates immediate expiration, and is set // unconditionally - this is an implementation detail that allows for entries // to be reused. -func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { - // Notify whoever is waiting on address resolution when transitioning - // out of incomplete. 
- if e.s == incomplete && ns != incomplete { - for w := range e.wakers { - w.Assert() - } - e.wakers = nil - if ch := e.done; ch != nil { - close(ch) - } - e.done = nil +// +// Precondition: e.mu must be locked +func (e *linkAddrEntry) changeStateLocked(ns entryState, expiration time.Time) { + if e.s == incomplete && ns == ready { + e.notifyCompletionLocked(e.linkAddr) } if expiration.IsZero() || expiration.After(e.expiration) { @@ -122,10 +124,6 @@ func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { e.s = ns } -func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { - delete(e.wakers, w) -} - // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { // Calculate expiration time before acquiring the lock, since expiration is @@ -135,10 +133,12 @@ func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { c.cache.Lock() entry := c.getOrCreateEntryLocked(k) - entry.linkAddr = v - - entry.changeState(ready, expiration) c.cache.Unlock() + + entry.mu.Lock() + defer entry.mu.Unlock() + entry.linkAddr = v + entry.changeStateLocked(ready, expiration) } // getOrCreateEntryLocked retrieves a cache entry associated with k. The @@ -159,13 +159,14 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt var entry *linkAddrEntry if len(c.cache.table) == linkAddrCacheSize { entry = c.cache.lru.Back() + entry.mu.Lock() delete(c.cache.table, entry.addr) c.cache.lru.Remove(entry) - // Wake waiters and mark the soon-to-be-reused entry as expired. Note - // that the state passed doesn't matter when the zero time is passed. - entry.changeState(failed, time.Time{}) + // Wake waiters and mark the soon-to-be-reused entry as expired. 
+ entry.notifyCompletionLocked("" /* linkAddr */) + entry.mu.Unlock() } else { entry = new(linkAddrEntry) } @@ -180,9 +181,12 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + if onResolve != nil { + onResolve(addr, true) + } return addr, nil, nil } } @@ -190,56 +194,35 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: + case ready: if !time.Now().After(entry.expiration) { // Not expired. - switch s { - case ready: - return entry.linkAddr, nil, nil - case failed: - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + if onResolve != nil { + onResolve(entry.linkAddr, true) } + return entry.linkAddr, nil, nil } - entry.changeState(incomplete, time.Time{}) + entry.changeStateLocked(incomplete, time.Time{}) fallthrough case incomplete: - if waker != nil { - if entry.wakers == nil { - entry.wakers = make(map[*sleep.Waker]struct{}) - } - entry.wakers[waker] = struct{}{} + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) } - if entry.done == nil { - // Address resolution needs to be initiated. 
- if linkRes == nil { - return entry.linkAddr, nil, tcpip.ErrNoLinkAddress - } - entry.done = make(chan struct{}) go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } - return entry.linkAddr, entry.done, tcpip.ErrWouldBlock default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } } -// removeWaker removes a waker previously added through get(). -func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.cache.Lock() - defer c.cache.Unlock() - - if entry, ok := c.cache.table[k]; ok { - entry.removeWaker(waker) - } -} - func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check @@ -257,9 +240,9 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link } } -// checkLinkRequest checks whether previous attempt to resolve address has succeeded -// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request -// can stop, false if another request should be sent. +// checkLinkRequest checks whether previous attempt to resolve address has +// succeeded and mark the entry accordingly. Returns true if request can stop, +// false if another request should be sent. func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { c.cache.Lock() defer c.cache.Unlock() @@ -268,16 +251,20 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, att // Entry was evicted from the cache. return true } + entry.mu.Lock() + defer entry.mu.Unlock() + switch s := entry.s; s { - case ready, failed: - // Entry was made ready by resolver or failed. Either way we're done. + case ready: + // Entry was made ready by resolver. 
case incomplete: if attempt+1 < c.resolutionAttempts { // No response yet, need to send another ARP request. return false } - // Max number of retries reached, mark entry as failed. - entry.changeState(failed, now.Add(c.ageLimit)) + // Max number of retries reached, delete entry. + entry.notifyCompletionLocked("" /* linkAddr */) + delete(c.cache.table, k) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index d2e37f38d..6883045b5 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -21,7 +21,6 @@ import ( "testing" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -50,6 +49,7 @@ type testLinkAddressResolver struct { } func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + // TODO(gvisor.dev/issue/5141): Use a fake clock. time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() @@ -78,16 +78,18 @@ func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumbe } func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - + var attemptedResolution bool for { - if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - return got, err + got, ch, err := c.get(addr, linkRes, "", nil, nil) + if err == tcpip.ErrWouldBlock { + if attemptedResolution { + return got, tcpip.ErrNoLinkAddress + } + attemptedResolution = true + <-ch + continue } - s.Fetch(true) + return got, err } } @@ -116,16 +118,19 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. 
+ c.cache.Lock() + defer c.cache.Unlock() for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { e := testAddrs[i] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } } func TestCacheConcurrent(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} var wg sync.WaitGroup for r := 0; r < 16; r++ { @@ -133,7 +138,6 @@ func TestCacheConcurrent(t *testing.T) { go func() { for _, e := range testAddrs { c.add(e.addr, e.linkAddr) - c.get(e.addr, nil, "", nil, nil) // make work for gotsan } wg.Done() }() @@ -144,7 +148,7 @@ func TestCacheConcurrent(t *testing.T) { // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. 
e := testAddrs[len(testAddrs)-1] - got, _, err := c.get(e.addr, nil, "", nil, nil) + got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) } @@ -153,18 +157,22 @@ func TestCacheConcurrent(t *testing.T) { } e = testAddrs[0] - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + c.cache.Lock() + defer c.cache.Unlock() + if entry, ok := c.cache.table[e.addr]; ok { + t.Errorf("unexpected entry at c.cache.table[%q]: %#v", string(e.addr.Addr), entry) } } func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + linkRes := &testLinkAddressResolver{cache: c} + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) - if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { - t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + if _, _, err := c.get(e.addr, linkRes, "", nil, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.get(%q) = %s, want = ErrWouldBlock", string(e.addr.Addr), err) } } @@ -282,71 +290,3 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } - -// TestCacheWaker verifies that RemoveWaker removes a waker previously added -// through get(). -func TestCacheWaker(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - - // First, sanity check that wakers are working. 
- { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 1 - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[0] - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - id, ok := s.Fetch(true /* block */) - if !ok { - t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") - } - if id != wakerID { - t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) - } - - if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - } - - // Check that RemoveWaker works. - { - linkRes := &testLinkAddressResolver{cache: c} - s := sleep.Sleeper{} - defer s.Done() - - const wakerID = 2 // different than the ID used in the sanity check - w := sleep.Waker{} - s.AddWaker(&w, wakerID) - - e := testAddrs[1] - linkRes.onLinkAddressRequest = func() { - // Remove the waker before the linkAddrCache has the opportunity to send - // a notification. 
- c.removeWaker(e.addr, &w) - } - - if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) - } - - if got, err := getBlocking(c, e.addr, linkRes); err != nil { - t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) - } else if got != e.linkAddr { - t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Fatalf("unexpected notification from waker with id %d", id) - } - } -} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 31b67b987..61636cae5 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -540,8 +540,8 @@ func TestDADResolve(t *testing.T) { // Make sure the right remote link address is used. snmc := header.SolicitedNodeAddr(addr1) - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } // Check NDP NS packet. 
@@ -577,11 +577,11 @@ func rxNDPSolicit(e *channel.Endpoint, tgt tcpip.Address) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: header.IPv6Any, - DstAddr: snmc, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: header.IPv6Any, + DstAddr: snmc, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) } @@ -623,11 +623,11 @@ func TestDADFail(t *testing.T) { payloadLength := hdr.UsedLength() ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: 255, - SrcAddr: tgt, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: 255, + SrcAddr: tgt, + DstAddr: header.IPv6AllNodesMulticastAddress, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{Data: hdr.View().ToVectorisedView()})) }, @@ -1011,11 +1011,11 @@ func raBufWithOptsAndDHCPv6(ip tcpip.Address, rl uint16, managedAddress, otherCo payloadLength := hdr.UsedLength() iph := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) iph.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: header.NDPHopLimit, - SrcAddr: ip, - DstAddr: header.IPv6AllNodesMulticastAddress, + PayloadLength: uint16(payloadLength), + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: header.NDPHopLimit, + SrcAddr: ip, + DstAddr: header.IPv6AllNodesMulticastAddress, }) return stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -5197,8 +5197,8 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right 
remote link address is used. - if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); got != want { - t.Errorf("got remote link address = %s, want = %s", got, want) + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 317f6871d..c15f10e76 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -99,9 +98,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA n.dynamic.lru.Remove(e) n.dynamic.count-- - e.dispatchRemoveEventLocked() - e.setStateLocked(Unknown) - e.notifyWakersLocked() + e.removeLocked() e.mu.Unlock() } n.cache[remoteAddr] = entry @@ -110,21 +107,27 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA return entry } -// entry looks up the neighbor cache for translating address to link address -// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there -// is a LinkAddressResolver registered with the network protocol, the cache -// attempts to resolve the address and returns ErrWouldBlock. If a Waker is -// provided, it will be notified when address resolution is complete (success -// or not). +// entry looks up neighbor information matching the remote address, and returns +// it if readily available. +// +// Returns ErrWouldBlock if the link address is not readily available, along +// with a notification channel for the caller to block on. Triggers address +// resolution asynchronously. 
+// +// If onResolve is provided, it will be called either immediately, if resolution +// is not required, or when address resolution is complete, with the resolved +// link address and whether resolution succeeded. After any callbacks have been +// called, the returned notification channel is closed. +// +// NB: if a callback is provided, it should not call into the neighbor cache. // // If specified, the local address must be an address local to the interface the // neighbor cache belongs to. The local address is the source address of a // packet prompting NUD/link address resolution. // -// If address resolution is required, ErrNoLinkAddress and a notification -// channel is returned for the top level caller to block. Channel is closed -// once address resolution is complete (success or not). -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + // TODO(gvisor.dev/issue/5149): Handle static resolution in route.Resolve. if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ Addr: remoteAddr, @@ -132,6 +135,9 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA State: Static, UpdatedAtNanos: 0, } + if onResolve != nil { + onResolve(linkAddr, true) + } return e, nil, nil } @@ -149,37 +155,25 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // of packets to a neighbor. While reasserting a neighbor's reachability, // a node continues sending packets to that neighbor using the cached // link-layer address." 
+ if onResolve != nil { + onResolve(entry.neigh.LinkAddr, true) + } return entry.neigh, nil, nil - case Unknown, Incomplete: - entry.addWakerLocked(w) - + case Unknown, Incomplete, Failed: + if onResolve != nil { + entry.onResolve = append(entry.onResolve, onResolve) + } if entry.done == nil { // Address resolution needs to be initiated. - if linkRes == nil { - return entry.neigh, nil, tcpip.ErrNoLinkAddress - } entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock - case Failed: - return entry.neigh, nil, tcpip.ErrNoLinkAddress default: panic(fmt.Sprintf("Invalid cache entry state: %s", s)) } } -// removeWaker removes a waker that has been added when link resolution for -// addr was requested. -func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { - n.mu.Lock() - if entry, ok := n.cache[addr]; ok { - delete(entry.wakers, waker) - } - n.mu.Unlock() -} - // entries returns all entries in the neighbor cache. func (n *neighborCache) entries() []NeighborEntry { n.mu.RLock() @@ -222,34 +216,13 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd return } - // Notify that resolution has been interrupted, just in case the entry was - // in the Incomplete or Probe state. - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } -// removeEntryLocked removes the specified entry from the neighbor cache. -// -// Prerequisite: n.mu and entry.mu MUST be locked. 
-func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { - if entry.neigh.State != Static { - n.dynamic.lru.Remove(entry) - n.dynamic.count-- - } - if entry.neigh.State != Failed { - entry.dispatchRemoveEventLocked() - } - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() - - delete(n.cache, entry.neigh.Addr) -} - // removeEntry removes a dynamic or static entry by address from the neighbor // cache. Returns true if the entry was found and deleted. func (n *neighborCache) removeEntry(addr tcpip.Address) bool { @@ -264,7 +237,13 @@ func (n *neighborCache) removeEntry(addr tcpip.Address) bool { entry.mu.Lock() defer entry.mu.Unlock() - n.removeEntryLocked(entry) + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + + entry.removeLocked() + delete(n.cache, entry.neigh.Addr) return true } @@ -275,9 +254,7 @@ func (n *neighborCache) clear() { for _, entry := range n.cache { entry.mu.Lock() - entry.dispatchRemoveEventLocked() - entry.setStateLocked(Unknown) - entry.notifyWakersLocked() + entry.removeLocked() entry.mu.Unlock() } diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 732a299f7..a2ed6ae2a 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" ) @@ -190,15 +189,18 @@ type testNeighborResolver struct { entries *testEntryStore delay time.Duration onLinkAddressRequest func() + dropReplies bool } var _ LinkAddressResolver = (*testNeighborResolver)(nil) func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { - // Delay handling the request to emulate network latency. 
- r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(targetAddr) - }) + if !r.dropReplies { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(targetAddr) + }) + } // Execute post address resolution action, if available. if f := r.onLinkAddressRequest; f != nil { @@ -291,10 +293,10 @@ func TestNeighborCacheEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -327,7 +329,7 @@ func TestNeighborCacheEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } // No more events should have been dispatched. 
@@ -354,11 +356,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -413,7 +415,7 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -461,7 +463,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -513,7 +515,7 @@ func (c *testContext) overflowCache(opts overflowOptions) error { } // Expect to find only the most recent entries. The order of entries reported - // by entries() is undeterministic, so entries have to be sorted before + // by entries() is nondeterministic, so entries have to be sorted before // comparison. 
wantUnsortedEntries := opts.wantStaticEntries for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { @@ -575,10 +577,10 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ @@ -650,7 +652,7 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -694,7 +696,7 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -756,7 +758,7 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { // Add a static entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } staticLinkAddr := entry.LinkAddr + "static" c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) @@ -826,10 +828,10 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := 
c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -907,150 +909,6 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { } } -func TestNeighborCacheNotifiesWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - id, ok := s.Fetch(false /* block */) - if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) - } - if id != wakerID { - t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - 
defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - -func TestNeighborCacheRemoveWaker(t *testing.T) { - config := DefaultNUDConfigurations() - - nudDisp := testNUDDispatcher{} - clock := faketime.NewManualClock() - neigh := newTestNeighborCache(&nudDisp, config, clock) - store := newTestEntryStore() - linkRes := &testNeighborResolver{ - clock: clock, - neigh: neigh, - entries: store, - delay: typicalLatency, - } - - w := sleep.Waker{} - s := sleep.Sleeper{} - const wakerID = 1 - s.AddWaker(&w, wakerID) - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) - if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) - } - - // Remove the waker before the neighbor cache has the opportunity to send a - // notification. 
- neigh.removeWaker(entry.Addr, &w) - clock.Advance(typicalLatency) - - select { - case <-doneCh: - default: - t.Fatal("expected notification from done channel") - } - - if id, ok := s.Fetch(false /* block */); ok { - t.Errorf("unexpected notification from waker with id %d", id) - } - - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: 1, - Entry: NeighborEntry{ - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, - }, - }, - } - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { config := DefaultNUDConfigurations() // Stay in Reachable so the cache can overflow @@ -1062,12 +920,12 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1075,7 +933,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ @@ -1129,10 +987,10 @@ func TestNeighborCacheClear(t *testing.T) { // Add a dynamic 
entry. entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1187,7 +1045,7 @@ func TestNeighborCacheClear(t *testing.T) { } } - // Clear shoud remove both dynamic and static entries. + // Clear should remove both dynamic and static entries. neigh.clear() // Remove events dispatched from clear() have no deterministic order so they @@ -1234,10 +1092,10 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { // Add a dynamic entry entry, ok := c.store.entry(0) if !ok { - t.Fatalf("c.store.entry(0) not found") + t.Fatal("c.store.entry(0) not found") } if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got c.neigh.entry(%s, '', _, nil, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ @@ -1318,7 +1176,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { frequentlyUsedEntry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } // The following logic is very similar to overflowCache, but @@ -1330,15 +1188,22 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got 
linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { @@ -1373,7 +1238,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1381,15 +1246,23 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + 
t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1435,7 +1308,7 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { } // Expect to find only the frequently used entry and the most recent entries. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { @@ -1494,12 +1367,12 @@ func TestNeighborCacheConcurrent(t *testing.T) { go func(entry NeighborEntry) { defer wg.Done() if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } - // Wait for all gorountines to send a request + // Wait for all goroutines to send a request wg.Wait() // Process all the requests for a single entry concurrently @@ -1509,7 +1382,7 @@ func TestNeighborCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than can fit in // the cache. Our eviction strategy requires that the last entries are // present, up to the size of the neighbor cache, and the rest are missing. - // The order of entries reported by entries() is undeterministic, so entries + // The order of entries reported by entries() is nondeterministic, so entries // have to be sorted before comparison. 
var wantUnsortedEntries []NeighborEntry for i := store.size() - neighborCacheSize; i < store.size(); i++ { @@ -1547,27 +1420,32 @@ func TestNeighborCacheReplace(t *testing.T) { // Add an entry entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { - case <-doneCh: + case <-ch: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) - } - if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) + t.Errorf("unexpected error from neigh.entry(%s, '', _, _, nil): %s", entry.Addr, err) } if t.Failed() { t.FailNow() @@ -1578,7 +1456,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } 
} @@ -1587,7 +1465,7 @@ func TestNeighborCacheReplace(t *testing.T) { { entry, ok := store.entry(1) if !ok { - t.Fatalf("store.entry(1) not found") + t.Fatal("store.entry(1) not found") } updatedLinkAddr = entry.LinkAddr } @@ -1604,7 +1482,7 @@ func TestNeighborCacheReplace(t *testing.T) { { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1612,7 +1490,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } @@ -1622,7 +1500,7 @@ func TestNeighborCacheReplace(t *testing.T) { e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1630,7 +1508,7 @@ func TestNeighborCacheReplace(t *testing.T) { State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1654,18 +1532,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { }, } - // First, sanity check that resolution is working entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - 
t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + // First, sanity check that resolution is working + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + t.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + clock.Advance(typicalLatency) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } - clock.Advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1673,20 +1568,35 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } - // Verify that address resolution for an unknown address returns ErrNoLinkAddress + // Verify address resolution fails for an unknown address. 
before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) - } - waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) - clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } maxAttempts := neigh.config().MaxUnicastProbes @@ -1714,15 +1624,129 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { entry, ok := store.entry(0) if !ok { - t.Fatalf("store.entry(0) not found") + t.Fatal("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + 
t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } +} + +// TestNeighborCacheRetryResolution simulates retrying communication after +// failing to perform address resolution. +func TestNeighborCacheRetryResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := faketime.NewManualClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + // Simulate a faulty link. + dropReplies: true, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatal("store.entry(0) not found") + } + + // Perform address resolution with a faulty link, which will fail. 
+ { + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if ok { + t.Error("expected unsuccessful address resolution") + } + if len(linkAddr) != 0 { + t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) + } + if t.Failed() { + t.FailNow() + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.Advance(waitFor) + + select { + case <-ch: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } + } + + // Verify the entry is in Failed state. + wantEntries := []NeighborEntry{ + { + Addr: entry.Addr, + LinkAddr: "", + State: Failed, + }, + } + if diff := cmp.Diff(neigh.entries(), wantEntries, entryDiffOptsWithSort()...); diff != "" { + t.Fatalf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // Retry address resolution with a working link. 
+ linkRes.dropReplies = false + { + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if linkAddr != entry.LinkAddr { + t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + } + if incompleteEntry.State != Incomplete { + t.Fatalf("got entry.State = %s, want = %s", incompleteEntry.State, Incomplete) + } + clock.Advance(typicalLatency) + + select { + case <-ch: + if !ok { + t.Fatal("expected successful address resolution") + } + reachableEntry, _, err := neigh.entry(entry.Addr, "", linkRes, nil) + if err != nil { + t.Fatalf("neigh.entry(%s, '', _, _, nil): %v", entry.Addr, err) + } + if reachableEntry.Addr != entry.Addr { + t.Fatalf("got entry.Addr = %s, want = %s", reachableEntry.Addr, entry.Addr) + } + if reachableEntry.LinkAddr != entry.LinkAddr { + t.Fatalf("got entry.LinkAddr = %s, want = %s", reachableEntry.LinkAddr, entry.LinkAddr) + } + if reachableEntry.State != Reachable { + t.Fatalf("got entry.State = %s, want = %s", reachableEntry.State.String(), Reachable.String()) + } + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) + } } } @@ -1742,7 +1766,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ Addr: testEntryBroadcastAddr, @@ -1750,7 +1774,7 @@ func TestNeighborCacheStaticResolution(t *testing.T) { State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", 
testEntryBroadcastAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1775,12 +1799,23 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) + + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { + if !ok { + b.Fatal("expected successful address resolution") + } + if linkAddr != entry.LinkAddr { + b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + } + }) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } - if doneCh != nil { - <-doneCh + + select { + case <-ch: + default: + b.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, _, nil)", entry.Addr) } } diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 32399b4f5..75afb3001 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -19,7 +19,6 @@ import ( "sync" "time" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,8 +66,7 @@ const ( // Static describes entries that have been explicitly added by the user. They // do not expire and are not deleted until explicitly removed. Static - // Failed means traffic should not be sent to this neighbor since attempts of - // reachability have returned inconclusive. + // Failed means recent attempts of reachability have returned inconclusive. Failed ) @@ -93,16 +91,13 @@ type neighborEntry struct { neigh NeighborEntry - // wakers is a set of waiters for address resolution result. Anytime state - // transitions out of incomplete these waiters are notified. 
It is nil iff - // address resolution is ongoing and no clients are waiting for the result. - wakers map[*sleep.Waker]struct{} - - // done is used to allow callers to wait on address resolution. It is nil - // iff nudState is not Reachable and address resolution is not yet in - // progress. + // done is closed when address resolution is complete. It is nil iff s is + // incomplete and resolution is not yet in progress. done chan struct{} + // onResolve is called with the result of address resolution. + onResolve []func(tcpip.LinkAddress, bool) + isRouter bool job *tcpip.Job } @@ -143,25 +138,15 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd } } -// addWaker adds w to the list of wakers waiting for address resolution. -// Assumes the entry has already been appropriately locked. -func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { - if w == nil { - return - } - if e.wakers == nil { - e.wakers = make(map[*sleep.Waker]struct{}) - } - e.wakers[w] = struct{}{} -} - -// notifyWakersLocked notifies those waiting for address resolution, whether it -// succeeded or failed. Assumes the entry has already been appropriately locked. -func (e *neighborEntry) notifyWakersLocked() { - for w := range e.wakers { - w.Assert() +// notifyCompletionLocked notifies those waiting for address resolution, with +// the link address if resolution completed successfully. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + for _, callback := range e.onResolve { + callback(e.neigh.LinkAddr, succeeded) } - e.wakers = nil + e.onResolve = nil if ch := e.done; ch != nil { close(ch) e.done = nil @@ -170,6 +155,8 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. +// +// Precondition: e.mu MUST be locked. 
func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborAdded(e.nic.id, e.neigh) @@ -178,6 +165,8 @@ func (e *neighborEntry) dispatchAddEventLocked() { // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborChanged(e.nic.id, e.neigh) @@ -186,23 +175,41 @@ func (e *neighborEntry) dispatchChangeEventLocked() { // dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry // has been removed. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } +// cancelJobLocked cancels the currently scheduled action, if there is one. +// Entries in Unknown, Stale, or Static state do not have a scheduled action. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) cancelJobLocked() { + if job := e.job; job != nil { + job.Cancel() + } +} + +// removeLocked prepares the entry for removal. +// +// Precondition: e.mu MUST be locked. +func (e *neighborEntry) removeLocked() { + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + e.dispatchRemoveEventLocked() + e.cancelJobLocked() + e.notifyCompletionLocked(false /* succeeded */) +} + // setStateLocked transitions the entry to the specified state immediately. // // Follows the logic defined in RFC 4861 section 7.3.3. // -// e.mu MUST be locked. +// Precondition: e.mu MUST be locked. func (e *neighborEntry) setStateLocked(next NeighborState) { - // Cancel the previously scheduled action, if there is one. Entries in - // Unknown, Stale, or Static state do not have scheduled actions. 
- if timer := e.job; timer != nil { - timer.Cancel() - } + e.cancelJobLocked() prev := e.neigh.State e.neigh.State = next @@ -257,11 +264,7 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(immediateDuration) case Failed: - e.notifyWakersLocked() - e.job = e.nic.stack.newJob(&doubleLock{first: &e.nic.neigh.mu, second: &e.mu}, func() { - e.nic.neigh.removeEntryLocked(e) - }) - e.job.Schedule(config.UnreachableTime) + e.notifyCompletionLocked(false /* succeeded */) case Unknown, Stale, Static: // Do nothing @@ -275,8 +278,14 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { + case Failed: + e.nic.stats.Neighbor.FailedEntryLookups.Increment() + + fallthrough case Unknown: e.neigh.State = Incomplete e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() @@ -309,7 +318,7 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // implementation may find it convenient in some cases to return errors // to the sender by taking the offending packet, generating an ICMP // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 + // error-handling routines." - RFC 4861 section 2.1 e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return @@ -349,8 +358,6 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { case Incomplete, Reachable, Delay, Probe, Static: // Do nothing - case Failed: - e.nic.stats.Neighbor.FailedEntryLookups.Increment() default: panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } @@ -360,18 +367,30 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { // Neighbor Solicitation for ARP or NDP, respectively). 
// // Follows the logic defined in RFC 4861 section 7.2.3. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // Probes MUST be silently discarded if the target address is tentative, does // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These // checks MUST be done by the NetworkEndpoint. switch e.neigh.State { - case Unknown, Incomplete, Failed: + case Unknown, Failed: e.neigh.LinkAddr = remoteLinkAddr e.setStateLocked(Stale) - e.notifyWakersLocked() e.dispatchAddEventLocked() + case Incomplete: + // "If an entry already exists, and the cached link-layer address + // differs from the one in the received Source Link-Layer option, the + // cached address should be replaced by the received address, and the + // entry's reachability state MUST be set to STALE." + // - RFC 4861 section 7.2.3 + e.neigh.LinkAddr = remoteLinkAddr + e.setStateLocked(Stale) + e.notifyCompletionLocked(true /* succeeded */) + e.dispatchChangeEventLocked() + case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr @@ -404,6 +423,8 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { // not be possible. SEND uses RSA key pairs to produce Cryptographically // Generated Addresses (CGA), as defined in RFC 3972. This ensures that the // claimed source of an NDP message is the owner of the claimed address. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { switch e.neigh.State { case Incomplete: @@ -422,7 +443,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) // "Note that the Override flag is ignored if the entry is in the // INCOMPLETE state." 
- RFC 4861 section 7.2.5 @@ -457,7 +478,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) - e.notifyWakersLocked() + e.notifyCompletionLocked(true /* succeeded */) if !wasReachable { e.dispatchChangeEventLocked() } @@ -495,6 +516,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // handleUpperLevelConfirmationLocked processes an incoming upper-level protocol // (e.g. TCP acknowledgements) reachability confirmation. +// +// Precondition: e.mu MUST be locked. func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: @@ -512,23 +535,3 @@ func (e *neighborEntry) handleUpperLevelConfirmationLocked() { panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) } } - -// doubleLock combines two locks into one while maintaining lock ordering. -// -// TODO(gvisor.dev/issue/4796): Remove this once subsequent traffic to a Failed -// neighbor is allowed. -type doubleLock struct { - first, second sync.Locker -} - -// Lock locks both locks in order: first then second. -func (l *doubleLock) Lock() { - l.first.Lock() - l.second.Lock() -} - -// Unlock unlocks both locks in reverse order: second then first. 
-func (l *doubleLock) Unlock() { - l.second.Unlock() - l.first.Unlock() -} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index c497d3932..ec34ffa5a 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -25,7 +25,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -73,36 +72,36 @@ func eventDiffOptsWithSort() []cmp.Option { // The following unit tests exercise every state transition and verify its // behavior with RFC 4681. // -// | From | To | Cause | Action | Event | -// | ========== | ========== | ========================================== | =============== | ======= | -// | Unknown | Unknown | Confirmation w/ unknown address | | Added | -// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | -// | Unknown | Stale | Probe w/ unknown address | | Added | -// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | -// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | -// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | -// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | -// | Reachable | Stale | Reachable timer expired | | Changed | -// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | -// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | -// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | -// | Stale | Delay | Packet queued | | 
Changed | -// | Delay | Reachable | Upper-layer confirmation | | Changed | -// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Delay | Stale | Probe or confirmation w/ different address | | Changed | -// | Delay | Probe | Delay timer expired | Send probe | Changed | -// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | -// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | -// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | -// | Probe | Stale | Probe or confirmation w/ different address | | Changed | -// | Probe | Probe | Retransmit timer expired | Send probe | Changed | -// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | -// | Failed | Failed | Packet queued | | | -// | Failed | | Unreachability timer expired | Delete entry | | +// | From | To | Cause | Update | Action | Event | +// | ========== | ========== | ========================================== | ======== | ===========| ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | | Send probe | Added | +// | Unknown | Stale | Probe | | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Unsolicited confirmation | LinkAddr | Notify | Changed | +// | Incomplete | Stale | Probe | LinkAddr | Notify | Changed | +// | Incomplete | Failed | Max probes sent without reply | | Notify | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | IsRouter | | | +// | Reachable | Stale | Reachable timer expired | | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | 
| | Changed | +// | Stale | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Stale | Stale | Override confirmation | LinkAddr | | Changed | +// | Stale | Stale | Probe w/ different address | LinkAddr | | Changed | +// | Stale | Delay | Packet sent | | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | | Changed | +// | Delay | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | | Changed | +// | Delay | Probe | Delay timer expired | | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | LinkAddr | | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | | Notify | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | | Notify | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | | Changed | +// | Probe | Probe | Retransmit timer expired | | | Changed | +// | Probe | Failed | Max probes sent without reply | | Notify | Removed | +// | Failed | Incomplete | Packet queued | | Send probe | Added | type testEntryEventType uint8 @@ -258,8 +257,8 @@ func TestEntryInitiallyUnknown(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -291,8 +290,8 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Unknown; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Unknown { + t.Errorf("got 
e.neigh.State = %q, want = %q", e.neigh.State, Unknown) } e.mu.Unlock() @@ -320,8 +319,8 @@ func TestEntryUnknownToIncomplete(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -367,8 +366,8 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -406,8 +405,8 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() @@ -560,21 +559,15 @@ func TestEntryIncompleteToReachable(t *testing.T) { nudDisp.mu.Unlock() } -// TestEntryAddsAndClearsWakers verifies that wakers are added when -// addWakerLocked is called and cleared when address resolution finishes. In -// this case, address resolution will finish when transitioning from Incomplete -// to Reachable. 
-func TestEntryAddsAndClearsWakers(t *testing.T) { +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) - w := sleep.Waker{} - s := sleep.Sleeper{} - s.AddWaker(&w, 123) - defer s.Done() - e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) + } e.mu.Unlock() runImmediatelyScheduledJobs(clock) @@ -593,26 +586,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Lock() - if got := e.wakers; got != nil { - t.Errorf("got e.wakers = %v, want = nil", got) - } - e.addWakerLocked(&w) - if got, want := w.IsAsserted(), false; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) - } - if e.wakers == nil { - t.Error("expected e.wakers to be non-nil") - } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, - IsRouter: false, + IsRouter: true, }) - if e.wakers != nil { - t.Errorf("got e.wakers = %v, want = nil", e.wakers) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } - if got, want := w.IsAsserted(), true; got != want { - t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) } e.mu.Unlock() @@ -643,7 +626,7 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { +func TestEntryIncompleteToStaleWhenUnsolicitedConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -663,22 +646,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { }, } linkRes.mu.Lock() - if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if 
diff != "" { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - linkRes.mu.Unlock() e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, + Solicited: false, Override: false, - IsRouter: true, + IsRouter: false, }) - if e.neigh.State != Reachable { - t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) - } - if !e.isRouter { - t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -698,7 +679,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { Entry: NeighborEntry{ Addr: entryTestAddr1, LinkAddr: entryTestLinkAddr1, - State: Reachable, + State: Stale, }, }, } @@ -709,7 +690,7 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryIncompleteToStale(t *testing.T) { +func TestEntryIncompleteToStaleWhenProbe(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, clock := entryTestSetup(c) @@ -736,11 +717,7 @@ func TestEntryIncompleteToStale(t *testing.T) { } e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) + e.handleProbeLocked(entryTestLinkAddr1) if e.neigh.State != Stale { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } @@ -780,8 +757,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() @@ -841,8 +818,8 @@ func TestEntryIncompleteToFailed(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - 
t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } e.mu.Unlock() } @@ -885,8 +862,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { Override: false, IsRouter: true, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.isRouter, true; got != want { t.Errorf("got e.isRouter = %t, want = %t", got, want) @@ -932,8 +909,8 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() } @@ -1083,8 +1060,8 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() } @@ -2381,8 +2358,8 @@ func TestEntryDelayToProbe(t *testing.T) { IsRouter: false, }) e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } e.mu.Unlock() @@ -2447,8 +2424,8 @@ func TestEntryDelayToProbe(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", 
e.neigh.State, Probe) } e.mu.Unlock() } @@ -2505,12 +2482,12 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2620,16 +2597,16 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) } e.mu.Unlock() @@ -2740,16 +2717,16 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, 
Probe) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2836,16 +2813,16 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -2964,16 +2941,16 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ Solicited: true, Override: true, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) @@ -3101,16 +3078,16 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin } e.mu.Lock() - if got, want := e.neigh.State, Probe; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if 
e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) } e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, IsRouter: false, }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) } e.mu.Unlock() @@ -3435,212 +3412,61 @@ func TestEntryProbeToFailed(t *testing.T) { nudDisp.mu.Unlock() } -func TestEntryFailedToFailed(t *testing.T) { +func TestEntryFailedToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 e, nudDisp, linkRes, clock := entryTestSetup(c) - // Verify the cache contains the entry. - if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) - } - // TODO(gvisor.dev/issue/4872): Use helper functions to start entry tests in // their expected state. 
e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } - - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - wantEvents := []testEntryEventInfo{ - { - EventType: entryTestAdded, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. 
{ - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, { - EventType: entryTestRemoved, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, - }, + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, }, } - nudDisp.mu.Lock() - if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } - nudDisp.mu.Unlock() - - failedLookups := e.nic.stats.Neighbor.FailedEntryLookups - if got := failedLookups.Value(); got != 0 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 0", got) + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } e.mu.Lock() - // Verify queuing a packet to the entry immediately fails. - e.handlePacketQueuedLocked(entryTestAddr2) - state := e.neigh.State - e.mu.Unlock() - if state != Failed { - t.Errorf("got e.neigh.State = %q, want = %q", state, Failed) - } - - if got := failedLookups.Value(); got != 1 { - t.Errorf("got Neighbor.FailedEntryLookups = %d, want = 1", got) - } -} - -func TestEntryFailedGetsDeleted(t *testing.T) { - c := DefaultNUDConfigurations() - c.MaxMulticastProbes = 3 - c.MaxUnicastProbes = 3 - e, nudDisp, linkRes, clock := entryTestSetup(c) - - // Verify the cache contains the entry. 
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { - t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } - - e.mu.Lock() - e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - runImmediatelyScheduledJobs(clock) - { - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.probes = nil - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - } - e.mu.Lock() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) e.handlePacketQueuedLocked(entryTestAddr2) - e.mu.Unlock() - - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime - clock.Advance(waitFor) - { - wantProbes := []entryTestProbeInfo{ - // The next three probe are sent in Probe. 
- { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { @@ -3653,39 +3479,21 @@ func TestEntryFailedGetsDeleted(t *testing.T) { }, }, { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, - }, - }, - { - EventType: entryTestChanged, - NICID: entryTestNICID, - Entry: NeighborEntry{ - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, - }, - }, - { - EventType: entryTestChanged, + EventType: entryTestRemoved, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, { - EventType: entryTestRemoved, + EventType: entryTestAdded, NICID: entryTestNICID, Entry: NeighborEntry{ Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, }, }, } @@ -3694,9 +3502,4 @@ func TestEntryFailedGetsDeleted(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - // Verify the cache no longer contains the entry. 
- if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { - t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) - } } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 5887aa1ed..4a34805b5 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -20,7 +20,6 @@ import ( "reflect" "sync/atomic" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -54,9 +53,9 @@ type NIC struct { sync.RWMutex spoofing bool promiscuous bool - // packetEPs is protected by mu, but the contained PacketEndpoint - // values are not. - packetEPs map[tcpip.NetworkProtocolNumber][]PacketEndpoint + // packetEPs is protected by mu, but the contained packetEndpointList are + // not. + packetEPs map[tcpip.NetworkProtocolNumber]*packetEndpointList } } @@ -82,6 +81,39 @@ type DirectionStats struct { Bytes *tcpip.StatCounter } +type packetEndpointList struct { + mu sync.RWMutex + + // eps is protected by mu, but the contained PacketEndpoint values are not. + eps []PacketEndpoint +} + +func (p *packetEndpointList) add(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + p.eps = append(p.eps, ep) +} + +func (p *packetEndpointList) remove(ep PacketEndpoint) { + p.mu.Lock() + defer p.mu.Unlock() + for i, epOther := range p.eps { + if epOther == ep { + p.eps = append(p.eps[:i], p.eps[i+1:]...) + break + } + } +} + +// forEach calls fn with each endpoints in p while holding the read lock on p. +func (p *packetEndpointList) forEach(fn func(PacketEndpoint)) { + p.mu.RLock() + defer p.mu.RUnlock() + for _, ep := range p.eps { + fn(ep) + } +} + // newNIC returns a new NIC using the default NDP configurations from stack. func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICContext) *NIC { // TODO(b/141011931): Validate a LinkEndpoint (ep) is valid. 
For @@ -102,7 +134,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC stats: makeNICStats(), networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint), } - nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber][]PacketEndpoint) + nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList) // Check for Neighbor Unreachability Detection support. var nud NUDHandler @@ -125,11 +157,11 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC // Register supported packet and network endpoint protocols. for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} + nic.mu.packetEPs[netProto] = new(packetEndpointList) } for _, netProto := range stack.networkProtocols { netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil + nic.mu.packetEPs[netNum] = new(packetEndpointList) nic.networkEndpoints[netNum] = netProto.NewEndpoint(nic, stack, nud, nic) } @@ -172,7 +204,7 @@ func (n *NIC) disable() { // // n MUST be locked. func (n *NIC) disableLocked() { - if !n.setEnabled(false) { + if !n.Enabled() { return } @@ -184,6 +216,10 @@ func (n *NIC) disableLocked() { for _, ep := range n.networkEndpoints { ep.Disable() } + + if !n.setEnabled(false) { + panic("should have only done work to disable the NIC if it was enabled") + } } // enable enables n. @@ -258,15 +294,17 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // the same unresolved IP address, and transmit the saved // packet when the address has been resolved. // - // RFC 4861 section 5.2 (for IPv6): - // Once the IP address of the next-hop node is known, the sender - // examines the Neighbor Cache for link-layer information about that - // neighbor. If no entry exists, the sender creates one, sets its state - // to INCOMPLETE, initiates Address Resolution, and then queues the data - // packet pending completion of address resolution. 
+ // RFC 4861 section 7.2.2 (for IPv6): + // While waiting for address resolution to complete, the sender MUST, for + // each neighbor, retain a small queue of packets waiting for address + // resolution to complete. The queue MUST hold at least one packet, and MAY + // contain more. However, the number of queued packets per neighbor SHOULD + // be limited to some small value. When a queue overflows, the new arrival + // SHOULD replace the oldest entry. Once address resolution completes, the + // node transmits any queued packets. if ch, err := r.Resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { - r := r.Clone() + r.Acquire() n.stack.linkResQueue.enqueue(ch, r, protocol, pkt) return nil } @@ -279,7 +317,9 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb // WritePacketToRemote implements NetworkInterface. func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { r := Route{ - NetProto: protocol, + routeInfo: routeInfo{ + NetProto: protocol, + }, } r.ResolveWith(remoteLinkAddr) return n.writePacket(&r, gso, protocol, pkt) @@ -508,14 +548,6 @@ func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { return n.neigh.entries(), nil } -func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { - if n.neigh == nil { - return - } - - n.neigh.removeWaker(addr, w) -} - func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { if n.neigh == nil { return tcpip.ErrNotSupported @@ -634,15 +666,23 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? - packetEPs := n.mu.packetEPs[protocol] - // Add any other packet type sockets that may be listening for all protocols. 
- packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) + protoEPs := n.mu.packetEPs[protocol] + // Other packet type sockets that are listening for all protocols. + anyEPs := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + // Deliver to interested packet endpoints without holding NIC lock. + deliverPacketEPs := func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketHost ep.HandlePacket(n.id, local, protocol, p) } + if protoEPs != nil { + protoEPs.forEach(deliverPacketEPs) + } + if anyEPs != nil { + anyEPs.forEach(deliverPacketEPs) + } // Parse headers. netProto := n.stack.NetworkProtocolInstance(protocol) @@ -683,16 +723,17 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // We do not deliver to protocol specific packet endpoints as on Linux // only ETH_P_ALL endpoints get outbound packets. // Add any other packet sockets that maybe listening for all protocols. - packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + eps := n.mu.packetEPs[header.EthernetProtocolAll] n.mu.RUnlock() - for _, ep := range packetEPs { + + eps.forEach(func(ep PacketEndpoint) { p := pkt.Clone() p.PktType = tcpip.PacketOutgoing // Add the link layer header as outgoing packets are intercepted // before the link layer header is created. n.LinkEndpoint.AddHeader(local, remote, protocol, p) ep.HandlePacket(n.id, local, protocol, p) - } + }) } // DeliverTransportPacket delivers the packets to the appropriate transport @@ -845,7 +886,7 @@ func (n *NIC) registerPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep Pa if !ok { return tcpip.ErrNotSupported } - n.mu.packetEPs[netProto] = append(eps, ep) + eps.add(ep) return nil } @@ -858,13 +899,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep if !ok { return } - - for i, epOther := range eps { - if epOther == ep { - n.mu.packetEPs[netProto] = append(eps[:i], eps[i+1:]...) 
- return - } - } + eps.remove(ep) } // isValidForOutgoing returns true if the endpoint can be used to send out a diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index ab629b3a4..12d67409a 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -109,14 +109,6 @@ const ( // // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. defaultMaxReachbilityConfirmations = 3 - - // defaultUnreachableTime is the default duration for how long an entry will - // remain in the FAILED state before being removed from the neighbor cache. - // - // Note, there is no equivalent protocol constant defined in RFC 4861. It - // leaves the specifics of any garbage collection mechanism up to the - // implementation. - defaultUnreachableTime = 5 * time.Second ) // NUDDispatcher is the interface integrators of netstack must implement to @@ -278,10 +270,6 @@ type NUDConfigurations struct { // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD // configuration option is necessary. MaxReachabilityConfirmations uint32 - - // UnreachableTime describes how long an entry will remain in the FAILED - // state before being removed from the neighbor cache. 
- UnreachableTime time.Duration } // DefaultNUDConfigurations returns a NUDConfigurations populated with default @@ -299,7 +287,6 @@ func DefaultNUDConfigurations() NUDConfigurations { MaxUnicastProbes: defaultMaxUnicastProbes, MaxAnycastDelayTime: defaultMaxAnycastDelayTime, MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, - UnreachableTime: defaultUnreachableTime, } } @@ -329,9 +316,6 @@ func (c *NUDConfigurations) resetInvalidFields() { if c.MaxUnicastProbes == 0 { c.MaxUnicastProbes = defaultMaxUnicastProbes } - if c.UnreachableTime == 0 { - c.UnreachableTime = defaultUnreachableTime - } } // calcMaxRandomFactor calculates the maximum value of the random factor used @@ -416,7 +400,7 @@ func (s *NUDState) ReachableTime() time.Duration { s.config.BaseReachableTime != s.prevBaseReachableTime || s.config.MinRandomFactor != s.prevMinRandomFactor || s.config.MaxRandomFactor != s.prevMaxRandomFactor { - return s.recomputeReachableTimeLocked() + s.recomputeReachableTimeLocked() } return s.reachableTime } @@ -442,7 +426,7 @@ func (s *NUDState) ReachableTime() time.Duration { // random value gets re-computed at least once every few hours. // // s.mu MUST be locked for writing. 
-func (s *NUDState) recomputeReachableTimeLocked() time.Duration { +func (s *NUDState) recomputeReachableTimeLocked() { s.prevBaseReachableTime = s.config.BaseReachableTime s.prevMinRandomFactor = s.config.MinRandomFactor s.prevMaxRandomFactor = s.config.MaxRandomFactor @@ -462,5 +446,4 @@ func (s *NUDState) recomputeReachableTimeLocked() time.Duration { } s.expiration = time.Now().Add(2 * time.Hour) - return s.reachableTime } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 8cffb9fc6..7bca1373e 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -37,7 +37,6 @@ const ( defaultMaxUnicastProbes = 3 defaultMaxAnycastDelayTime = time.Second defaultMaxReachbilityConfirmations = 3 - defaultUnreachableTime = 5 * time.Second defaultFakeRandomNum = 0.5 ) @@ -565,58 +564,6 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { } } -func TestNUDConfigurationsUnreachableTime(t *testing.T) { - tests := []struct { - name string - unreachableTime time.Duration - want time.Duration - }{ - // Invalid cases - { - name: "EqualToZero", - unreachableTime: 0, - want: defaultUnreachableTime, - }, - // Valid cases - { - name: "MoreThanZero", - unreachableTime: time.Millisecond, - want: time.Millisecond, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - const nicID = 1 - - c := stack.DefaultNUDConfigurations() - c.UnreachableTime = test.unreachableTime - - e := channel.New(0, 1280, linkAddr1) - e.LinkEPCapabilities |= stack.CapabilityResolutionRequired - - s := stack.New(stack.Options{ - // A neighbor cache is required to store NUDConfigurations. The - // networking stack will only allocate neighbor caches if a protocol - // providing link address resolution is specified (e.g. ARP or IPv6). 
- NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol}, - NUDConfigs: c, - UseNeighborCache: true, - }) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - sc, err := s.NUDConfigurations(nicID) - if err != nil { - t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) - } - if got := sc.UnreachableTime; got != test.want { - t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) - } - }) - } -} - // TestNUDStateReachableTime verifies the correctness of the ReachableTime // computation. func TestNUDStateReachableTime(t *testing.T) { diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index 5d364a2b0..4a3adcf33 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -103,7 +103,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro for _, p := range packets { if cancelled { p.route.Stats().IP.OutgoingPacketErrors.Increment() - } else if _, err := p.route.Resolve(nil); err != nil { + } else if p.route.IsResolutionRequired() { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index b334e27c4..7e83b7fbb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -799,19 +798,26 @@ type LinkAddressCache interface { // AddLinkAddress adds a link address to the cache. AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) - // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). 
- // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver - // registered with the network protocol, the cache attempts to resolve the address - // and returns ErrWouldBlock. Waker is notified when address resolution is - // complete (success or not). + // GetLinkAddress finds the link address corresponding to the remote address + // (e.g. IP -> MAC). // - // If address resolution is required, ErrNoLinkAddress and a notification channel is - // returned for the top level caller to block. Channel is closed once address resolution - // is complete (success or not). - GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) - - // RemoveWaker removes a waker that has been added in GetLinkAddress(). - RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) + // Returns a link address for the remote address, if readily available. + // + // Returns ErrWouldBlock if the link address is not readily available, along + // with a notification channel for the caller to block on. Triggers address + // resolution asynchronously. + // + // If onResolve is provided, it will be called either immediately, if + // resolution is not required, or when address resolution is complete, with + // the resolved link address and whether resolution succeeded. After any + // callbacks have been called, the returned notification channel is closed. + // + // If specified, the local address must be an address local to the interface + // the neighbor cache belongs to. The local address is the source address of + // a packet prompting NUD/link address resolution. + // + // TODO(gvisor.dev/issue/5151): Don't return the link address. 
+ GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) } // RawFactory produces endpoints for writing various types of raw packets. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index de5fe6ffe..b0251d0b4 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -17,7 +17,6 @@ package stack import ( "fmt" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -31,24 +30,7 @@ import ( // // TODO(gvisor.dev/issue/4902): Unexpose immutable fields. type Route struct { - // RemoteAddress is the final destination of the route. - RemoteAddress tcpip.Address - - // LocalAddress is the local address where the route starts. - LocalAddress tcpip.Address - - // LocalLinkAddress is the link-layer (MAC) address of the - // where the route starts. - LocalLinkAddress tcpip.LinkAddress - - // NextHop is the next node in the path to the destination. - NextHop tcpip.Address - - // NetProto is the network-layer protocol. - NetProto tcpip.NetworkProtocolNumber - - // Loop controls where WritePacket should send packets. - Loop PacketLooping + routeInfo // localAddressNIC is the interface the address is associated with. // TODO(gvisor.dev/issue/4548): Remove this field once we can query the @@ -78,6 +60,45 @@ type Route struct { linkRes LinkAddressResolver } +type routeInfo struct { + // RemoteAddress is the final destination of the route. + RemoteAddress tcpip.Address + + // LocalAddress is the local address where the route starts. + LocalAddress tcpip.Address + + // LocalLinkAddress is the link-layer (MAC) address of the + // where the route starts. + LocalLinkAddress tcpip.LinkAddress + + // NextHop is the next node in the path to the destination. + NextHop tcpip.Address + + // NetProto is the network-layer protocol. 
+ NetProto tcpip.NetworkProtocolNumber + + // Loop controls where WritePacket should send packets. + Loop PacketLooping +} + +// RouteInfo contains all of Route's exported fields. +type RouteInfo struct { + routeInfo + + // RemoteLinkAddress is the link-layer (MAC) address of the next hop in the + // route. + RemoteLinkAddress tcpip.LinkAddress +} + +// GetFields returns a RouteInfo with all of r's exported fields. This allows +// callers to store the route's fields without retaining a reference to it. +func (r *Route) GetFields() RouteInfo { + return RouteInfo{ + routeInfo: r.routeInfo, + RemoteLinkAddress: r.RemoteLinkAddress(), + } +} + // constructAndValidateRoute validates and initializes a route. It takes // ownership of the provided local address. // @@ -152,13 +173,15 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) *Route { r := &Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - localAddressNIC: localAddressNIC, - outgoingNIC: outgoingNIC, - Loop: loop, + routeInfo: routeInfo{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + Loop: loop, + }, + localAddressNIC: localAddressNIC, + outgoingNIC: outgoingNIC, } r.mu.Lock() @@ -264,22 +287,21 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) { r.mu.remoteLinkAddress = addr } -// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in -// case address resolution requires blocking, e.g. wait for ARP reply. Waker is -// notified when address resolution is complete (success or not). +// Resolve attempts to resolve the link address if necessary. 
// -// If address resolution is required, ErrNoLinkAddress and a notification channel is -// returned for the top level caller to block. Channel is closed once address resolution -// is complete (success or not). -// -// The NIC r uses must not be locked. -func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { +// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g. +// waiting for ARP reply). If address resolution is required, a notification +// channel is also returned for the caller to block on. The channel is closed +// once address resolution is complete (successful or not). If a callback is +// provided, it will be called when address resolution is complete, regardless +// of success or failure. +func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) { r.mu.Lock() - defer r.mu.Unlock() if !r.isResolutionRequiredRLocked() { // Nothing to do if there is no cache (which does the resolution on cache miss) or // link address is already known. + r.mu.Unlock() return nil, nil } @@ -288,6 +310,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { // Local link address is already known. if r.RemoteAddress == r.LocalAddress { r.mu.remoteLinkAddress = r.LocalLinkAddress + r.mu.Unlock() return nil, nil } nextAddr = r.RemoteAddress @@ -300,38 +323,36 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { linkAddressResolutionRequestLocalAddr = r.LocalAddress } + // Increment the route's reference count because finishResolution retains a + // reference to the route and releases it when called. 
+ r.acquireLocked() + r.mu.Unlock() + + finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) { + if ok { + r.ResolveWith(linkAddress) + } + if afterResolve != nil { + afterResolve() + } + r.Release() + } + if neigh := r.outgoingNIC.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) + _, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = entry.LinkAddr return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) + _, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, finishResolution) if err != nil { return ch, err } - r.mu.remoteLinkAddress = linkAddr return nil, nil } -// RemoveWaker removes a waker that has been added in Resolve(). -func (r *Route) RemoveWaker(waker *sleep.Waker) { - nextAddr := r.NextHop - if nextAddr == "" { - nextAddr = r.RemoteAddress - } - - if neigh := r.outgoingNIC.neigh; neigh != nil { - neigh.removeWaker(nextAddr, waker) - return - } - - r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) -} - // local returns true if the route is a local route. func (r *Route) local() bool { return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() @@ -419,46 +440,31 @@ func (r *Route) MTU() uint32 { return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } -// Release frees all resources associated with the route. +// Release decrements the reference counter of the resources associated with the +// route. func (r *Route) Release() { r.mu.Lock() defer r.mu.Unlock() - if r.mu.localAddressEndpoint != nil { - r.mu.localAddressEndpoint.DecRef() - r.mu.localAddressEndpoint = nil + if ep := r.mu.localAddressEndpoint; ep != nil { + ep.DecRef() } } -// Clone clones the route. 
-func (r *Route) Clone() *Route { +// Acquire increments the reference counter of the resources associated with the +// route. +func (r *Route) Acquire() { r.mu.RLock() defer r.mu.RUnlock() + r.acquireLocked() +} - newRoute := &Route{ - RemoteAddress: r.RemoteAddress, - LocalAddress: r.LocalAddress, - LocalLinkAddress: r.LocalLinkAddress, - NextHop: r.NextHop, - NetProto: r.NetProto, - Loop: r.Loop, - localAddressNIC: r.localAddressNIC, - outgoingNIC: r.outgoingNIC, - linkCache: r.linkCache, - linkRes: r.linkRes, - } - - newRoute.mu.Lock() - defer newRoute.mu.Unlock() - newRoute.mu.localAddressEndpoint = r.mu.localAddressEndpoint - if newRoute.mu.localAddressEndpoint != nil { - if !newRoute.mu.localAddressEndpoint.IncRef() { - panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", newRoute.LocalAddress)) +func (r *Route) acquireLocked() { + if ep := r.mu.localAddressEndpoint; ep != nil { + if !ep.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) } } - newRoute.mu.remoteLinkAddress = r.mu.remoteLinkAddress - - return newRoute } // Stack returns the instance of the Stack that owns this route. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index dc4f5b3e7..114643b03 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -29,7 +29,6 @@ import ( "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/rand" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -171,6 +170,9 @@ type TCPSenderState struct { // Outstanding is the number of packets in flight. Outstanding int + // SackedOut is the number of packets which have been selectively acked. + SackedOut int + // SndWnd is the send window size in bytes. 
SndWnd seqnum.Size @@ -1517,7 +1519,7 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, addr tcpip.Address, linkAddr t } // GetLinkAddress implements LinkAddressCache.GetLinkAddress. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { s.mu.RLock() nic := s.nics[nicID] if nic == nil { @@ -1528,7 +1530,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, onResolve) } // Neighbors returns all IP to MAC address associations. @@ -1544,29 +1546,6 @@ func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { return nic.neighbors() } -// RemoveWaker removes a waker that has been added when link resolution for -// addr was requested. -func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { - if s.useNeighborCache { - s.mu.RLock() - nic, ok := s.nics[nicID] - s.mu.RUnlock() - - if ok { - nic.removeWaker(addr, waker) - } - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - if nic := s.nics[nicID]; nic == nil { - fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} - s.linkAddrCache.removeWaker(fullAddr, waker) - } -} - // AddStaticNeighbor statically associates an IP address to a MAC address. 
func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { s.mu.RLock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 457990945..856ebf6d4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1602,7 +1602,10 @@ func TestOutgoingBroadcastWithEmptyRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: header.IPv4Any, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = header.IPv4Any + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", header.IPv4Any, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1656,7 +1659,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(1, %v, %v, %d) failed: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + var wantRoute stack.Route + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(1, %v, %v, %d) returned unexpected Route: %v", nic1Addr.Address, header.IPv4Broadcast, fakeNetNumber, err) } @@ -1666,7 +1672,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic2Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic2Addr.Address + wantRoute.RemoteAddress = 
header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } @@ -1682,7 +1691,10 @@ func TestOutgoingBroadcastWithRouteTable(t *testing.T) { if err != nil { t.Fatalf("FindRoute(0, \"\", %s, %d) failed: %s", header.IPv4Broadcast, fakeNetNumber, err) } - if err := verifyRoute(r, &stack.Route{LocalAddress: nic1Addr.Address, RemoteAddress: header.IPv4Broadcast}); err != nil { + wantRoute = stack.Route{} + wantRoute.LocalAddress = nic1Addr.Address + wantRoute.RemoteAddress = header.IPv4Broadcast + if err := verifyRoute(r, &wantRoute); err != nil { t.Errorf("FindRoute(0, \"\", %s, %d) returned unexpected Route: %s)", header.IPv4Broadcast, fakeNetNumber, err) } } @@ -2726,8 +2738,16 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") - nicID = 1 - lifetimeSeconds = 9999 + globalAddr3 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03") + ipv4MappedIPv6Addr1 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x01") + ipv4MappedIPv6Addr2 = tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x02") + toredoAddr1 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + toredoAddr2 = tcpip.Address("\x20\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + ipv6ToIPv4Addr1 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + ipv6ToIPv4Addr2 = tcpip.Address("\x20\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + nicID = 1 + lifetimeSeconds = 9999 ) prefix1, _, stableGlobalAddr1 := 
prefixSubnetAddr(0, linkAddr1) @@ -2744,139 +2764,191 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { slaacPrefixForTempAddrBeforeNICAddrAdd tcpip.AddressWithPrefix nicAddrs []tcpip.Address slaacPrefixForTempAddrAfterNICAddrAdd tcpip.AddressWithPrefix - connectAddr tcpip.Address + remoteAddr tcpip.Address expectedLocalAddr tcpip.Address }{ - // Test Rule 1 of RFC 6724 section 5. + // Test Rule 1 of RFC 6724 section 5 (prefer same address). { name: "Same Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: globalAddr1, expectedLocalAddr: globalAddr1, }, { name: "Same Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr1, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr1, expectedLocalAddr: linkLocalAddr1, }, { name: "Same Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Same Unique Local most preferred (first address)", - nicAddrs: 
[]tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr1, + nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr1, expectedLocalAddr: uniqueLocalAddr1, }, - // Test Rule 2 of RFC 6724 section 5. + // Test Rule 2 of RFC 6724 section 5 (prefer appropriate scope). { name: "Global most preferred (last address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Global most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: globalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: globalAddr2, expectedLocalAddr: globalAddr1, }, { name: "Link Local most preferred (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalAddr2, + nicAddrs: []tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (last address)", - nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local most preferred for link local multicast (first address)", - nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: linkLocalMulticastAddr, + nicAddrs: 
[]tcpip.Address{linkLocalAddr1, globalAddr1}, + remoteAddr: linkLocalMulticastAddr, expectedLocalAddr: linkLocalAddr1, }, + + // Test Rule 6 of 6724 section 5 (prefer matching label). { name: "Unique Local most preferred (last address)", - nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Unique Local most preferred (first address)", - nicAddrs: []tcpip.Address{globalAddr1, linkLocalAddr1, uniqueLocalAddr1}, - connectAddr: uniqueLocalAddr2, + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, toredoAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1}, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, + { + name: "Toredo most preferred (first address)", + nicAddrs: []tcpip.Address{toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "Toredo most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1}, + remoteAddr: toredoAddr2, + expectedLocalAddr: toredoAddr1, + }, + { + name: "6To4 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "6To4 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv4MappedIPv6Addr1, uniqueLocalAddr1, toredoAddr1, ipv6ToIPv4Addr1}, + remoteAddr: ipv6ToIPv4Addr2, + expectedLocalAddr: ipv6ToIPv4Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (first address)", + nicAddrs: []tcpip.Address{ipv4MappedIPv6Addr1, ipv6ToIPv4Addr1, toredoAddr1, uniqueLocalAddr1, globalAddr1}, + remoteAddr: 
ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, + { + name: "IPv4 mapped IPv6 most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, ipv6ToIPv4Addr1, uniqueLocalAddr1, toredoAddr1, ipv4MappedIPv6Addr1}, + remoteAddr: ipv4MappedIPv6Addr2, + expectedLocalAddr: ipv4MappedIPv6Addr1, + }, - // Test Rule 7 of RFC 6724 section 5. + // Test Rule 7 of RFC 6724 section 5 (prefer temporary addresses). { name: "Temp Global most preferred (last address)", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, { name: "Temp Global most preferred (first address)", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1}, slaacPrefixForTempAddrAfterNICAddrAdd: prefix1, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: tempGlobalAddr1, }, + // Test Rule 8 of RFC 6724 section 5 (use longest matching prefix). + { + name: "Longest prefix matched most preferred (first address)", + nicAddrs: []tcpip.Address{globalAddr2, globalAddr1}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + { + name: "Longest prefix matched most preferred (last address)", + nicAddrs: []tcpip.Address{globalAddr1, globalAddr2}, + remoteAddr: globalAddr3, + expectedLocalAddr: globalAddr2, + }, + // Test returning the endpoint that is closest to the front when // candidate addresses are "equal" from the perspective of RFC 6724 // section 5. 
{ name: "Unique Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, uniqueLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: uniqueLocalAddr1, }, { name: "Link Local for Global", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: globalAddr2, + remoteAddr: globalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Link Local for Unique Local", nicAddrs: []tcpip.Address{linkLocalAddr1, linkLocalAddr2}, - connectAddr: uniqueLocalAddr2, + remoteAddr: uniqueLocalAddr2, expectedLocalAddr: linkLocalAddr1, }, { name: "Temp Global for Global", slaacPrefixForTempAddrBeforeNICAddrAdd: prefix1, slaacPrefixForTempAddrAfterNICAddrAdd: prefix2, - connectAddr: globalAddr1, + remoteAddr: globalAddr1, expectedLocalAddr: tempGlobalAddr2, }, } @@ -2898,12 +2970,6 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - s.SetRouteTable([]tcpip.Route{{ - Destination: header.IPv6EmptySubnet, - Gateway: llAddr3, - NIC: nicID, - }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) if test.slaacPrefixForTempAddrBeforeNICAddrAdd != (tcpip.AddressWithPrefix{}) { e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, test.slaacPrefixForTempAddrBeforeNICAddrAdd, true, true, lifetimeSeconds, lifetimeSeconds)) @@ -2923,7 +2989,23 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { t.FailNow() } - if got := addrForNewConnectionTo(t, s, tcpip.FullAddress{Addr: test.connectAddr, NIC: nicID, Port: 1234}); got != test.expectedLocalAddr { + netEP, err := s.GetNetworkEndpoint(nicID, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } + + addressableEndpoint, ok := netEP.(stack.AddressableEndpoint) + if !ok { + t.Fatal("network endpoint is not addressable") + } + + addressEP := 
addressableEndpoint.AcquireOutgoingPrimaryAddress(test.remoteAddr, false /* allowExpired */) + if addressEP == nil { + t.Fatal("expected a non-nil address endpoint") + } + defer addressEP.DecRef() + + if got := addressEP.AddressWithPrefix().Address; got != test.expectedLocalAddr { t.Errorf("got local address = %s, want = %s", got, test.expectedLocalAddr) } }) @@ -4204,8 +4286,8 @@ func TestWritePacketToRemote(t *testing.T) { if got, want := pkt.Proto, test.protocol; got != want { t.Fatalf("pkt.Proto = %d, want %d", got, want) } - if got, want := pkt.Route.RemoteLinkAddress(), linkAddr2; got != want { - t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", got, want) + if pkt.Route.RemoteLinkAddress != linkAddr2 { + t.Fatalf("pkt.Route.RemoteAddress = %s, want %s", pkt.Route.RemoteLinkAddress, linkAddr2) } if diff := cmp.Diff(pkt.Pkt.Data.ToView(), buffer.View(test.payload)); diff != "" { t.Errorf("pkt.Pkt.Data mismatch (-want +got):\n%s", diff) diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 2cdb5ca79..737d8d912 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -141,11 +141,11 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NI // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: testSrcAddrV6, - DstAddr: testDstAddrV6, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: testSrcAddrV6, + DstAddr: testDstAddrV6, }) // Initialize the UDP header. 
@@ -308,9 +308,8 @@ func TestBindToDeviceDistribution(t *testing.T) { defer ep.Close() ep.SocketOptions().SetReusePort(endpoint.reuse) - bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) + if err := ep.SocketOptions().SetBindToDevice(int32(endpoint.bindToDevice)); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", endpoint.bindToDevice, endpoint.bindToDevice, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index d9769e47d..dd552b8b9 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -77,6 +77,7 @@ func (f *fakeTransportEndpoint) Abort() { } func (f *fakeTransportEndpoint) Close() { + // TODO(gvisor.dev/issue/5153): Consider retaining the route. f.route.Release() } @@ -109,8 +110,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. Currently not supported. @@ -146,16 +147,16 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return tcpip.ErrNoRoute } - defer r.Release() // Try to register so that we can start receiving packets. 
f.ID.RemoteAddress = addr.Addr err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */) if err != nil { + r.Release() return err } - f.route = r.Clone() + f.route = r return nil } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 2bd472811..ef0f51f1a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -49,8 +49,9 @@ const ipv4AddressSize = 4 // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // -// Note: to support save / restore, it is important that all tcpip errors have -// distinct error messages. +// All errors must have unique msg strings. +// +// +stateify savable type Error struct { msg string @@ -257,6 +258,44 @@ func (a Address) Unspecified() bool { return true } +// MatchingPrefix returns the matching prefix length in bits. +// +// Panics if b and a have different lengths. +func (a Address) MatchingPrefix(b Address) uint8 { + const bitsInAByte = 8 + + if len(a) != len(b) { + panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b)) + } + + var prefix uint8 + for i := range a { + aByte := a[i] + bByte := b[i] + + if aByte == bByte { + prefix += bitsInAByte + continue + } + + // Count the remaining matching bits in the byte from MSbit to LSBbit. + mask := uint8(1) << (bitsInAByte - 1) + for { + if aByte&mask == bByte&mask { + prefix++ + mask >>= 1 + continue + } + + break + } + + break + } + + return prefix +} + // AddressMask is a bitmask for an address. type AddressMask string @@ -491,6 +530,17 @@ type ControlMessages struct { // PacketInfo holds interface and address data on an incoming packet. PacketInfo IPPacketInfo + + // HasOriginalDestinationAddress indicates whether OriginalDstAddress is + // set. 
+ HasOriginalDstAddress bool + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress FullAddress + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr *SockError } // PacketOwner is used to get UID and GID of the packet. @@ -545,7 +595,7 @@ type Endpoint interface { // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. - Peek([][]byte) (int64, ControlMessages, *Error) + Peek([][]byte) (int64, *Error) // Connect connects the endpoint to its peer. Specifying a NIC is // optional. @@ -905,14 +955,6 @@ type SettableSocketOption interface { isSettableSocketOption() } -// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets -// should bind only on a specific NIC. -type BindToDeviceOption NICID - -func (*BindToDeviceOption) isGettableSocketOption() {} - -func (*BindToDeviceOption) isSettableSocketOption() {} - // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -1087,14 +1129,6 @@ type RemoveMembershipOption MembershipOption func (*RemoveMembershipOption) isSettableSocketOption() {} -// OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether -// TCP out-of-band data is delivered along with the normal in-band data. -type OutOfBandInlineOption int - -func (*OutOfBandInlineOption) isGettableSocketOption() {} - -func (*OutOfBandInlineOption) isSettableSocketOption() {} - // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int @@ -1144,10 +1178,6 @@ type LingerOption struct { Timeout time.Duration } -func (*LingerOption) isGettableSocketOption() {} - -func (*LingerOption) isSettableSocketOption() {} - // IPPacketInfo is the message structure for IP_PKTINFO. 
// // +stateify savable diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index c461da137..9bd563c46 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -270,3 +270,43 @@ func TestAddressUnspecified(t *testing.T) { }) } } + +func TestAddressMatchingPrefix(t *testing.T) { + tests := []struct { + addrA Address + addrB Address + prefix uint8 + }{ + { + addrA: "\x01\x01", + addrB: "\x01\x01", + prefix: 16, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x00", + prefix: 15, + }, + { + addrA: "\x01\x01", + addrB: "\x81\x00", + prefix: 0, + }, + { + addrA: "\x01\x01", + addrB: "\x01\x80", + prefix: 8, + }, + { + addrA: "\x01\x01", + addrB: "\x02\x80", + prefix: 6, + }, + } + + for _, test := range tests { + if got := test.addrA.MatchingPrefix(test.addrB); got != test.prefix { + t.Errorf("got (%s).MatchingPrefix(%s) = %d, want = %d", test.addrA, test.addrB, got, test.prefix) + } + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 8be791a00..2e59f6a42 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -96,11 +96,11 @@ func TestPingMulticastBroadcast(t *testing.T) { pkt.SetChecksum(header.ICMPv6Checksum(pkt, remoteIPv6Addr, dst, buffer.VectorisedView{})) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - PayloadLength: header.ICMPv6MinimumSize, - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: header.ICMPv6MinimumSize, + TransportProtocol: icmp.ProtocolNumber6, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -272,11 +272,11 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ - 
PayloadLength: uint16(payloadLen), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: ttl, - SrcAddr: remoteIPv6Addr, - DstAddr: dst, + PayloadLength: uint16(payloadLen), + TransportProtocol: udp.ProtocolNumber, + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, }) e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 94fcd72d9..d1e4a7cb7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -75,8 +75,6 @@ type endpoint struct { route *stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -332,21 +330,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.SocketDetachFilterOption: - return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - } return nil } @@ -399,16 +388,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -524,7 +504,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: r.LocalAddress, @@ -539,11 +518,12 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, err = e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } e.ID = id - e.route = r.Clone() + e.route = r e.RegisterNICID = nicID e.state = stateConnected diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 3666bac0f..e5e247342 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -85,8 +85,6 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -206,8 +204,8 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha } // Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be @@ -306,16 +304,10 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. 
func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - ep.mu.Lock() - ep.linger = *v - ep.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -374,18 +366,16 @@ func (ep *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (ep *endpoint) UpdateLastError(err *tcpip.Error) { + ep.lastErrorMu.Lock() + ep.lastError = err + ep.lastErrorMu.Unlock() +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - ep.mu.Lock() - *o = ep.linger - ep.mu.Unlock() - return nil - - default: - return tcpip.ErrNotSupported - } + return tcpip.ErrNotSupported } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 0840a4b3d..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -85,8 +85,6 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -227,6 +225,13 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrInvalidOptionValue } + if opts.To != nil { + // Raw sockets do not support sending to a IPv4 address on a IPv6 endpoint. 
+ if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(opts.To.Addr) != header.IPv6AddressSize { + return 0, nil, tcpip.ErrInvalidOptionValue + } + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -256,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -273,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -295,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. 
nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -335,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -386,8 +364,8 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } // Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -397,6 +375,11 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. + if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { + return tcpip.ErrAddressFamilyNotSupported + } + e.mu.Lock() defer e.mu.Unlock() @@ -425,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -437,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. 
- e.route = route.Clone() + e.route = route e.connected = true return nil @@ -520,16 +503,10 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { + switch opt.(type) { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() - return nil - default: return tcpip.ErrUnknownProtocolOption } @@ -581,16 +558,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.mu.Lock() - *o = e.linger - e.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -625,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + e.mu.RLock() e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -637,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // sockets. 
if e.rcvClosed || !e.associated { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() return @@ -644,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return @@ -655,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() + e.mu.RUnlock() return } // If bound to an address, only accept data for that address. if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } } @@ -668,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // connected to. if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } @@ -702,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() + e.mu.RUnlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. 
if wasEmpty { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3e1041cbe..2d96a65bd 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -778,7 +778,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr) }() - s := sleep.Sleeper{} + var s sleep.Sleeper s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) for { diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index c944dccc0..0dc710276 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -462,7 +462,7 @@ func (h *handshake) processSegments() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resolutionWaker := &sleep.Waker{} s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -470,24 +470,27 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution + attemptedResolution := false for { switch index { case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - if err == tcpip.ErrNoLinkAddress { - h.ep.stats.SendErrors.NoLinkAddr.Increment() - } else if err != nil { + if _, err := h.ep.route.Resolve(resolutionWaker.Assert); err != tcpip.ErrWouldBlock { + if err != nil { h.ep.stats.SendErrors.NoRoute.Increment() } // Either success (err == nil) or failure. return err } + if attemptedResolution { + h.ep.stats.SendErrors.NoLinkAddr.Increment() + return tcpip.ErrNoLinkAddress + } + attemptedResolution = true // Resolution not completed. Keep trying... 
case wakerForNotification: n := h.ep.fetchNotifications() if n&notifyClose != 0 { - h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } if n&notifyDrain != 0 { @@ -563,7 +566,7 @@ func (h *handshake) start() *tcpip.Error { // complete completes the TCP 3-way handshake initiated by h.start(). func (h *handshake) complete() *tcpip.Error { // Set up the wakers. - s := sleep.Sleeper{} + var s sleep.Sleeper resendWaker := sleep.Waker{} s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&h.ep.notificationWaker, wakerForNotification) @@ -1512,7 +1515,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } // Initialize the sleeper based on the wakers in funcs. - s := sleep.Sleeper{} + var s sleep.Sleeper for i := range funcs { s.AddWaker(funcs[i].w, i) } @@ -1699,7 +1702,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { const notification = 2 const timeWaitDone = 3 - s := sleep.Sleeper{} + var s sleep.Sleeper defer s.Done() s.AddWaker(&e.newSegmentWaker, newSegment) s.AddWaker(&e.notificationWaker, notification) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 87eda2efb..6e3c8860e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -502,9 +502,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // bindToDevice is set to the NIC on which to bind or disabled if 0. - bindToDevice tcpip.NICID - // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -674,9 +671,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. 
ops tcpip.SocketOptions } @@ -1040,7 +1034,8 @@ func (e *endpoint) Close() { return } - if e.linger.Enabled && e.linger.Timeout == 0 { + linger := e.SocketOptions().GetLinger() + if linger.Enabled && linger.Timeout == 0 { s := e.EndpointState() isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv if isResetState { @@ -1305,6 +1300,15 @@ func (e *endpoint) LastError() *tcpip.Error { return e.lastErrorLocked() } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.LockUser() + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + e.UnlockUser() +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() @@ -1498,7 +1502,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // Peek reads data without consuming it from the endpoint. // // This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1506,10 +1510,10 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // but has some pending unread data. 
if s := e.EndpointState(); !s.connected() && s != StateClose { if s == StateError { - return 0, tcpip.ControlMessages{}, e.hardErrorLocked() + return 0, e.hardErrorLocked() } e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + return 0, tcpip.ErrInvalidEndpointState } e.rcvListMu.Lock() @@ -1518,9 +1522,9 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return 0, tcpip.ErrClosedForReceive } - return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return 0, tcpip.ErrWouldBlock } // Make a copy of vec so we can modify the slide headers. @@ -1535,7 +1539,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro for len(v) > 0 { if len(vec) == 0 { - return num, tcpip.ControlMessages{}, nil + return num, nil } if len(vec[0]) == 0 { vec = vec[1:] @@ -1550,7 +1554,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro } } - return num, tcpip.ControlMessages{}, nil + return num, nil } // selectWindowLocked returns the new window without checking for shrinking or scaling @@ -1814,18 +1818,13 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt sets a socket option. 
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.LockUser() - e.bindToDevice = id - e.UnlockUser() - case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() e.keepalive.idle = time.Duration(*v) @@ -1838,9 +1837,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. - case *tcpip.TCPUserTimeoutOption: e.LockUser() e.userTimeout = time.Duration(*v) @@ -1909,11 +1905,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { case *tcpip.SocketDetachFilterOption: return nil - case *tcpip.LingerOption: - e.LockUser() - e.linger = *v - e.UnlockUser() - default: return nil } @@ -2014,11 +2005,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case *tcpip.BindToDeviceOption: - e.LockUser() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.UnlockUser() - case *tcpip.TCPInfoOption: *o = tcpip.TCPInfoOption{} e.LockUser() @@ -2046,10 +2032,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { *o = tcpip.TCPUserTimeoutOption(e.userTimeout) e.UnlockUser() - case *tcpip.OutOfBandInlineOption: - // We don't currently support disabling this option. 
- *o = 1 - case *tcpip.CongestionControlOption: e.LockUser() *o = e.cc @@ -2078,11 +2060,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { Port: port, } - case *tcpip.LingerOption: - e.LockUser() - *o = e.linger - e.UnlockUser() - default: return tcpip.ErrUnknownProtocolOption } @@ -2230,11 +2207,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } } + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2272,15 +2250,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. 
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } id := e.ID id.LocalPort = p - if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) if err == tcpip.ErrPortInUse { return false, nil } @@ -2291,7 +2269,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // the selected port. e.ID = id e.isPortReserved = true - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags e.boundDest = addr return true, nil @@ -2302,7 +2280,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc e.isRegistered = true e.setEndpointState(StateConnecting) - e.route = r.Clone() + r.Acquire() + e.route = r e.boundNICID = nicID e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr @@ -2643,7 +2622,8 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { id := e.ID id.LocalPort = p 
// CheckRegisterTransportEndpoint should only return an error if there is a @@ -2654,7 +2634,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // demuxer. Further connected endpoints always have a remote // address/port. Hence this will only return an error if there is a matching // listening endpoint. - if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { return false } return true @@ -2663,7 +2643,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - e.boundBindToDevice = e.bindToDevice + e.boundBindToDevice = bindToDevice e.boundPortFlags = e.portFlags // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. e.boundNICID = nic @@ -2727,6 +2707,41 @@ func (e *endpoint) enqueueSegment(s *segment) bool { return true } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + // Linux passes the payload with the TCP header. We don't know if the TCP + // header even exists, it may not for fragmented packets. + Payload: pkt.Data.ToView(), + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. 
+ e.notifyProtocolGoroutine(notifyError) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { switch typ { @@ -2741,16 +2756,10 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.notifyProtocolGoroutine(notifyMTUChanged) case stack.ControlNoRoute: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNoRoute - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNoRoute, id, byte(header.ICMPv4DstUnreachable), byte(header.ICMPv4HostUnreachable), extra, pkt) case stack.ControlNetworkUnreachable: - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrNetworkUnreachable - e.lastErrorMu.Unlock() - e.notifyProtocolGoroutine(notifyError) + e.onICMPError(tcpip.ErrNetworkUnreachable, id, byte(header.ICMPv6DstUnreachable), byte(header.ICMPv6NetworkUnreachable), extra, pkt) } } @@ -3008,6 +3017,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { Ssthresh: e.snd.sndSsthresh, SndCAAckCount: e.snd.sndCAAckCount, Outstanding: e.snd.outstanding, + SackedOut: e.snd.sackedOut, SndWnd: e.snd.sndWnd, SndUna: e.snd.sndUna, SndNxt: e.snd.sndNxt, diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index f2b1b68da..405a6dce7 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -172,14 +172,12 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // If we started off with a window larger than what can he held in // the 16bit window field, we ceil the value to the max value. - // While ceiling, we still do not want to grow the right edge when - // not applicable. 
if scaledWnd > math.MaxUint16 { - if toGrow { - scaledWnd = seqnum.Size(math.MaxUint16) - } else { - scaledWnd = seqnum.Size(uint16(scaledWnd)) - } + scaledWnd = seqnum.Size(math.MaxUint16) + + // Ensure that the stashed receive window always reflects what + // is being advertised. + r.rcvWnd = scaledWnd << r.rcvWndScale } return r.rcvNxt, scaledWnd } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index baec762e1..cc991aba6 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -137,6 +137,9 @@ type sender struct { // that have been sent but not yet acknowledged. outstanding int + // sackedOut is the number of packets which are selectively acked. + sackedOut int + // sndWnd is the send window size. sndWnd seqnum.Size @@ -372,6 +375,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { m = 1 } + oldMSS := s.maxPayloadSize s.maxPayloadSize = m if s.gso { s.ep.gso.MSS = uint16(m) @@ -394,6 +398,7 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { // Rewind writeNext to the first segment exceeding the MTU. Do nothing // if it is already before such a packet. + nextSeg := s.writeNext for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { if seg == s.writeNext { // We got to writeNext before we could find a segment @@ -401,16 +406,22 @@ func (s *sender) updateMaxPayloadSize(mtu, count int) { break } - if seg.data.Size() > m { + if nextSeg == s.writeNext && seg.data.Size() > m { // We found a segment exceeding the MTU. Rewind // writeNext and try to retransmit it. - s.writeNext = seg - break + nextSeg = seg + } + + if s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Update sackedOut for new maximum payload size. + s.sackedOut -= s.pCount(seg, oldMSS) + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } } // Since we likely reduced the number of outstanding packets, we may be // ready to send some more. 
+ s.writeNext = nextSeg s.sendData() } @@ -629,13 +640,13 @@ func (s *sender) retransmitTimerExpired() bool { // pCount returns the number of packets in the segment. Due to GSO, a segment // can be composed of multiple packets. -func (s *sender) pCount(seg *segment) int { +func (s *sender) pCount(seg *segment, maxPayloadSize int) int { size := seg.data.Size() if size == 0 { return 1 } - return (size-1)/s.maxPayloadSize + 1 + return (size-1)/maxPayloadSize + 1 } // splitSeg splits a given segment at the size specified and inserts the @@ -1023,7 +1034,7 @@ func (s *sender) sendData() { break } dataSent = true - s.outstanding += s.pCount(seg) + s.outstanding += s.pCount(seg, s.maxPayloadSize) s.writeNext = seg.Next() } @@ -1038,6 +1049,7 @@ func (s *sender) enterRecovery() { // We inflate the cwnd by 3 to account for the 3 packets which triggered // the 3 duplicate ACKs and are now not in flight. s.sndCwnd = s.sndSsthresh + 3 + s.sackedOut = 0 s.fr.first = s.sndUna s.fr.last = s.sndNxt - 1 s.fr.maxCwnd = s.sndCwnd + s.outstanding @@ -1207,6 +1219,7 @@ func (s *sender) walkSACK(rcvdSeg *segment) { s.rc.update(seg, rcvdSeg, s.ep.tsOffset) s.rc.detectReorder(seg) seg.acked = true + s.sackedOut += s.pCount(seg, s.maxPayloadSize) } seg = seg.Next() } @@ -1380,10 +1393,10 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { datalen := seg.logicalLen() if datalen > ackLeft { - prevCount := s.pCount(seg) + prevCount := s.pCount(seg, s.maxPayloadSize) seg.data.TrimFront(int(ackLeft)) seg.sequenceNumber.UpdateForward(ackLeft) - s.outstanding -= prevCount - s.pCount(seg) + s.outstanding -= prevCount - s.pCount(seg, s.maxPayloadSize) break } @@ -1399,11 +1412,13 @@ func (s *sender) handleRcvdSegment(rcvdSeg *segment) { s.writeList.Remove(seg) - // If SACK is enabled then Only reduce outstanding if + // If SACK is enabled then only reduce outstanding if // the segment was not previously SACKED as these have // already been accounted for in SetPipe(). 
if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { - s.outstanding -= s.pCount(seg) + s.outstanding -= s.pCount(seg, s.maxPayloadSize) + } else { + s.sackedOut -= s.pCount(seg, s.maxPayloadSize) } seg.decRef() ackLeft -= datalen diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index ef7f5719f..faf0c0ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -590,3 +590,45 @@ func TestSACKRecovery(t *testing.T) { expected++ } } + +// TestSACKUpdateSackedOut tests the sacked out field is updated when a SACK +// is received. +func TestSACKUpdateSackedOut(t *testing.T) { + c := context.New(t, uint32(mtu)) + defer c.Cleanup() + + probeDone := make(chan struct{}) + ackNum := 0 + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint Sender.SackedOut is what we expect. + if state.Sender.SackedOut != 2 && ackNum == 0 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 2", state.Sender.SackedOut) + } + + if state.Sender.SackedOut != 0 && ackNum == 1 { + t.Fatalf("SackedOut got updated to wrong value got: %v want: 0", state.Sender.SackedOut) + } + if ackNum > 0 { + close(probeDone) + } + ackNum++ + }) + setStackSACKPermitted(t, c, true) + createConnectedWithSACKAndTS(c) + + sendAndReceive(t, c, 8) + + // ACK for [3-5] packets. + seq := seqnum.Value(context.TestInitialSequenceNumber).Add(1) + start := c.IRS.Add(seqnum.Size(1 + 3*maxPayload)) + bytesRead := 2 * maxPayload + end := start.Add(seqnum.Size(bytesRead)) + c.SendAckWithSACK(seq, bytesRead, []header.SACKBlock{{start, end}}) + + bytesRead += 3 * maxPayload + c.SendAck(seq, bytesRead) + + // Wait for the probe function to finish processing the ACK before the + // test completes. 
+ <-probeDone +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 1759ebea9..cf60d5b53 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1380,9 +1380,8 @@ func TestConnectBindToDevice(t *testing.T) { defer c.Cleanup() c.Create(-1) - bindToDevice := tcpip.BindToDeviceOption(test.device) - if err := c.EP.SetSockOpt(&bindToDevice); err != nil { - t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + if err := c.EP.SocketOptions().SetBindToDevice(int32(test.device)); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", test.device, test.device, err) } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) @@ -1932,6 +1931,84 @@ func TestFullWindowReceive(t *testing.T) { ) } +// Test the stack receive window advertisement on receiving segments smaller than +// segment overhead. It tests for the right edge of the window to not grow when +// the endpoint is not being read from. +func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize, + Max: tcp.DefaultReceiveBufferSize << tcp.FindWndScale(seqnum.Size(tcp.DefaultReceiveBufferSize)), + } + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } + + c.AcceptWithOptions(tcp.FindWndScale(seqnum.Size(opt.Default)), header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Bump up the receive buffer size such that, when the receive window grows, + // the scaled window exceeds maxUint16. 
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, opt.Max); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", opt.Max, err) + } + + // Keep the payload size < segment overhead and such that it is a multiple + // of the window scaled value. This enables the test to perform equality + // checks on the incoming receive window. + payload := generateRandomPayload(t, (tcp.SegSize-1)&(1<<c.RcvdWindowScale)) + payloadLen := seqnum.Size(len(payload)) + iss := seqnum.Value(789) + seqNum := iss.Add(1) + + // Send payload to the endpoint and return the advertised receive window + // from the endpoint. + getIncomingRcvWnd := func() uint32 { + c.SendPacket(payload, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: seqNum, + AckNum: c.IRS.Add(1), + Flags: header.TCPFlagAck, + RcvWnd: 30000, + }) + seqNum = seqNum.Add(payloadLen) + + pkt := c.GetPacket() + return uint32(header.TCP(header.IPv4(pkt).Payload()).WindowSize()) << c.RcvdWindowScale + } + + // Read the advertised receive window with the ACK for payload. + rcvWnd := getIncomingRcvWnd() + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. + if got, want := getIncomingRcvWnd(), rcvWnd-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Read the data so that the subsequent ACK from the endpoint + // grows the right edge of the window. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("got Read(nil) = %s", err) + } + + // Check if we have received max uint16 as our advertised + // scaled window now after a read above. + maxRcv := uint32(math.MaxUint16 << c.RcvdWindowScale) + if got, want := getIncomingRcvWnd(), maxRcv; got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } + + // Check if the subsequent ACK to our send has not grown the right edge of + // the window. 
+ if got, want := getIncomingRcvWnd(), maxRcv-uint32(len(payload)); got != want { + t.Fatalf("got incomingRcvwnd %d want %d", got, want) + } +} + func TestNoWindowShrinking(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -4148,7 +4225,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. peekBuf := make([]byte, 10) - n, _, err := c.EP.Peek([][]byte{peekBuf}) + n, err := c.EP.Peek([][]byte{peekBuf}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4174,7 +4251,7 @@ func TestReadAfterClosedState(t *testing.T) { t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) } - if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { + if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) } } @@ -4429,7 +4506,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -4439,15 +4516,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", 
bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 010a23e45..ee55f030c 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -635,11 +635,11 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(header.TCPMinimumSize + len(payload)), - NextHeader: uint8(tcp.ProtocolNumber), - HopLimit: 65, - SrcAddr: src, - DstAddr: dst, + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + TransportProtocol: tcp.ProtocolNumber, + HopLimit: 65, + SrcAddr: src, + DstAddr: dst, }) // Initialize the TCP header. diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 5043e7aa5..9b9e4deb0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -30,10 +30,11 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry - senderAddress tcpip.FullAddress - packetInfo tcpip.IPPacketInfo - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - timestamp int64 + senderAddress tcpip.FullAddress + destinationAddress tcpip.FullAddress + packetInfo tcpip.IPPacketInfo + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 // tos stores either the receiveTOS or receiveTClass value. 
tos uint8 } @@ -108,7 +109,6 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID portFlags ports.Flags - bindToDevice tcpip.NICID lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -143,9 +143,6 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -228,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error { return err } +// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError. +func (e *endpoint) UpdateLastError(err *tcpip.Error) { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() +} + // Abort implements stack.TransportEndpoint.Abort. func (e *endpoint) Abort() { e.Close() @@ -323,6 +327,10 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } + if e.ops.GetReceiveOriginalDstAddress() { + cm.HasOriginalDstAddress = true + cm.OriginalDstAddress = p.destinationAddress + } return p.data.ToView(), cm, nil } @@ -509,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. + so := e.SocketOptions() + if so.GetRecvError() { + so.QueueLocalErr( + tcpip.ErrMessageTooLong, + route.NetProto, + header.UDPMaximumPacketSize, + tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + Port: dstPort, + }, + v, + ) + } return 0, nil, tcpip.ErrMessageTooLong } @@ -545,8 +567,8 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } // Peek only returns data from a single datagram, so do nothing here. 
-func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil +func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { + return 0, nil } // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. @@ -636,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -752,22 +778,8 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil - - case *tcpip.LingerOption: - e.mu.Lock() - e.linger = *v - e.mu.Unlock() } return nil } @@ -841,16 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - - case *tcpip.LingerOption: - e.mu.RLock() - *o = e.linger - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1004,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer r.Release() id := stack.TransportEndpointID{ LocalAddress: e.ID.LocalAddress, @@ -1032,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { id, btd, err := e.registerWithStack(nicID, netProtos, id) if err != nil { + r.Release() return err } @@ -1042,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.ID = id e.boundBindToDevice = btd - e.route = 
r.Clone() + e.route = r e.dstPort = addr.Port e.RegisterNICID = nicID e.effectiveNetProtos = netProtos @@ -1100,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { @@ -1311,6 +1314,11 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, + destinationAddress: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: header.UDP(hdr).DestinationPort(), + }, } packet.data = pkt.Data e.rcvList.PushBack(packet) @@ -1341,15 +1349,63 @@ 
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB } } +func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) { + // Update last error first. + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + // Update the error queue if IP_RECVERR is enabled. + if e.SocketOptions().GetRecvError() { + // Linux passes the payload without the UDP header. + var payload []byte + udp := header.UDP(pkt.Data.ToView()) + if len(udp) >= header.UDPMinimumSize { + payload = udp.Payload() + } + + e.SocketOptions().QueueErr(&tcpip.SockError{ + Err: err, + ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber), + ErrType: errType, + ErrCode: errCode, + ErrInfo: extra, + Payload: payload, + Dst: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.RemoteAddress, + Port: id.RemotePort, + }, + Offender: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: id.LocalAddress, + Port: id.LocalPort, + }, + NetProto: pkt.NetworkProtocolNumber, + }) + } + + // Notify of the error. + e.waiterQueue.Notify(waiter.EventErr) +} + // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. 
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { if e.EndpointState() == StateConnected { - e.lastErrorMu.Lock() - e.lastError = tcpip.ErrConnectionRefused - e.lastErrorMu.Unlock() - - e.waiterQueue.Notify(waiter.EventErr) + var errType byte + var errCode byte + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + errType = byte(header.ICMPv4DstUnreachable) + errCode = byte(header.ICMPv4PortUnreachable) + case header.IPv6ProtocolNumber: + errType = byte(header.ICMPv6DstUnreachable) + errCode = byte(header.ICMPv6PortUnreachable) + default: + panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber)) + } + e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt) return } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index e384f52dd..8429f34b4 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -452,12 +452,12 @@ func (c *testContext) 
buildV6Packet(payload []byte, h *header4Tuple) buffer.View // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(header.UDPMinimumSize + len(payload)), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. @@ -554,7 +554,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -564,15 +564,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) @@ -1427,6 +1425,93 @@ func TestReadIPPacketInfo(t 
*testing.T) { } } +func TestReadRecvOriginalDstAddr(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedOriginalDstAddr tcpip.FullAddress + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackAddr, stackPort}, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, multicastAddr, stackPort}, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedOriginalDstAddr: tcpip.FullAddress{1, broadcastAddr, stackPort}, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedOriginalDstAddr: tcpip.FullAddress{1, stackV6Addr, stackPort}, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). 
+ expectedOriginalDstAddr: tcpip.FullAddress{1, multicastV6Addr, stackPort}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%#v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) + } + } + + c.ep.SocketOptions().SetReceiveOriginalDstAddress(true) + + testRead(c, test.flow, checker.ReceiveOriginalDstAddr(test.expectedOriginalDstAddr)) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1994,12 +2079,12 @@ func TestShortHeader(t *testing.T) { // Initialize the IP header. ip := header.IPv6(buf) ip.Encode(&header.IPv6Fields{ - TrafficClass: testTOS, - PayloadLength: uint16(udpSize), - NextHeader: uint8(udp.ProtocolNumber), - HopLimit: 65, - SrcAddr: h.srcAddr.Addr, - DstAddr: h.dstAddr.Addr, + TrafficClass: testTOS, + PayloadLength: uint16(udpSize), + TransportProtocol: udp.ProtocolNumber, + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, }) // Initialize the UDP header. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 2bf0a22ff..7b5fcef9c 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -55,11 +55,8 @@ type Container struct { copyErr error cleanups []func() - // Profiles are profiles added to this container. 
They contain methods - // that are run after Creation, Start, and Cleanup of this Container, along - // a handle to restart the profile. Generally, tests/benchmarks using - // profiles need to run as root. - profiles []Profile + // profile is the profiling hook associated with this container. + profile *profile } // RunOpts are options for running a container. @@ -105,22 +102,7 @@ type RunOpts struct { Links []string } -// MakeContainer sets up the struct for a Docker container. -// -// Names of containers will be unique. -// Containers will check flags for profiling requests. -func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { - c := MakeNativeContainer(ctx, logger) - c.runtime = *runtime - if p := MakePprofFromFlags(c); p != nil { - c.AddProfile(p) - } - return c -} - -// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native -// containers aren't profiled. -func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { +func makeContainer(ctx context.Context, logger testutil.Logger, runtime string) *Container { // Slashes are not allowed in container names. name := testutil.RandomID(logger.Name()) name = strings.ReplaceAll(name, "/", "-") @@ -132,24 +114,29 @@ func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container return &Container{ logger: logger, Name: name, - runtime: "", + runtime: runtime, client: client, } } -// AddProfile adds a profile to this container. -func (c *Container) AddProfile(p Profile) { - c.profiles = append(c.profiles, p) +// MakeContainer constructs a suitable Container object. +// +// The runtime used is determined by the runtime flag. +// +// Containers will check flags for profiling requests. +func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := makeContainer(ctx, logger, *runtime) + c.profileInit() + return c } -// RestartProfiles calls Restart on all profiles for this container. 
-func (c *Container) RestartProfiles() error { - for _, profile := range c.profiles { - if err := profile.Restart(c); err != nil { - return err - } - } - return nil +// MakeNativeContainer constructs a suitable Container object. +// +// The runtime used will be the system default. +// +// Native containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { + return makeContainer(ctx, logger, "" /*runtime*/) } // Spawn is analogous to 'docker run -d'. @@ -206,6 +193,8 @@ func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, return "", err } + c.stopProfiling() + return c.Logs(ctx) } @@ -236,11 +225,6 @@ func (c *Container) create(ctx context.Context, conf *container.Config, hostconf return err } c.id = cont.ID - for _, profile := range c.profiles { - if err := profile.OnCreate(c); err != nil { - return fmt.Errorf("OnCreate method failed with: %v", err) - } - } return nil } @@ -286,11 +270,13 @@ func (c *Container) Start(ctx context.Context) error { if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { return fmt.Errorf("ContainerStart failed: %v", err) } - for _, profile := range c.profiles { - if err := profile.OnStart(c); err != nil { - return fmt.Errorf("OnStart method failed: %v", err) + + if c.profile != nil { + if err := c.profile.Start(c); err != nil { + c.logger.Logf("profile.Start failed: %v", err) } } + return nil } @@ -499,8 +485,18 @@ func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, t } } +// stopProfiling stops profiling. +func (c *Container) stopProfiling() { + if c.profile != nil { + if err := c.profile.Stop(c); err != nil { + c.logger.Logf("profile.Stop failed: %v", err) + } + } +} + // Kill kills the container. 
func (c *Container) Kill(ctx context.Context) error { + c.stopProfiling() return c.client.ContainerKill(ctx, c.id, "") } @@ -517,14 +513,6 @@ func (c *Container) Remove(ctx context.Context) error { // CleanUp kills and deletes the container (best effort). func (c *Container) CleanUp(ctx context.Context) { - // Execute profile cleanups before the container goes down. - for _, profile := range c.profiles { - profile.OnCleanUp(c) - } - - // Forget profiles. - c.profiles = nil - // Execute all cleanups. We execute cleanups here to close any // open connections to the container before closing. Open connections // can cause Kill and Remove to hang. @@ -538,10 +526,12 @@ func (c *Container) CleanUp(ctx context.Context) { // Just log; can't do anything here. c.logger.Logf("error killing container %q: %v", c.Name, err) } + // Remove the image. if err := c.Remove(ctx); err != nil { c.logger.Logf("error removing container %q: %v", c.Name, err) } + // Forget all mounts. c.mounts = nil } diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 7027df1a5..a40005799 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -49,15 +49,11 @@ var ( // pprofBaseDir allows the user to change the directory to which profiles are // written. By default, profiles will appear under: // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. - pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") - - // duration is the max duration `runsc debug` will run and capture profiles. - // If the container's clean up method is called prior to duration, the - // profiling process will be killed. - duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. 
/tmp/profile/runtime/mycontainer/cpu.pprof)") + pprofDuration = flag.Duration("pprof-duration", time.Hour, "profiling duration (automatically stopped at container exit)") // The below flags enable each type of profile. Multiple profiles can be - // enabled for each run. + // enabled for each run. The profile will be collected from the start. pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 55f9496cd..f1103eb6e 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -17,72 +17,57 @@ package dockerutil import ( "context" "fmt" - "io" "os" "os/exec" "path/filepath" + "syscall" "time" ) -// Profile represents profile-like operations on a container, -// such as running perf or pprof. It is meant to be added to containers -// such that the container type calls the Profile during its lifecycle. -type Profile interface { - // OnCreate is called just after the container is created when the container - // has a valid ID (e.g. c.ID()). - OnCreate(c *Container) error - - // OnStart is called just after the container is started when the container - // has a valid Pid (e.g. c.SandboxPid()). - OnStart(c *Container) error - - // Restart restarts the Profile on request. - Restart(c *Container) error - - // OnCleanUp is called during the container's cleanup method. - // Cleanups should just log errors if they have them. - OnCleanUp(c *Container) error -} - -// Pprof is for running profiles with 'runsc debug'. Pprof workloads -// should be run as root and ONLY against runsc sandboxes. The runtime -// should have --profile set as an option in /etc/docker/daemon.json in -// order for profiling to work with Pprof. 
-type Pprof struct { - BasePath string // path to put profiles - BlockProfile bool - CPUProfile bool - HeapProfile bool - MutexProfile bool - Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. - shouldRun bool - cmd *exec.Cmd - stdout io.ReadCloser - stderr io.ReadCloser +// profile represents profile-like operations on a container. +// +// It is meant to be added to containers such that the container type calls +// the profile during its lifecycle. Standard implementations are below. + +// profile is for running profiles with 'runsc debug'. +type profile struct { + BasePath string + Types []string + Duration time.Duration + cmd *exec.Cmd } -// MakePprofFromFlags makes a Pprof profile from flags. -func MakePprofFromFlags(c *Container) *Pprof { - if !(*pprofBlock || *pprofCPU || *pprofHeap || *pprofMutex) { - return nil +// profileInit initializes a profile object, if required. +func (c *Container) profileInit() { + if !*pprofBlock && !*pprofCPU && !*pprofMutex && !*pprofHeap { + return // Nothing to do. + } + c.profile = &profile{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + Duration: *pprofDuration, } - return &Pprof{ - BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), - BlockProfile: *pprofBlock, - CPUProfile: *pprofCPU, - HeapProfile: *pprofHeap, - MutexProfile: *pprofMutex, - Duration: *duration, + if *pprofCPU { + c.profile.Types = append(c.profile.Types, "cpu") + } + if *pprofHeap { + c.profile.Types = append(c.profile.Types, "heap") + } + if *pprofMutex { + c.profile.Types = append(c.profile.Types, "mutex") + } + if *pprofBlock { + c.profile.Types = append(c.profile.Types, "block") } } -// OnCreate implements Profile.OnCreate. -func (p *Pprof) OnCreate(c *Container) error { - return os.MkdirAll(p.BasePath, 0755) -} +// createProcess creates the collection process. +func (p *profile) createProcess(c *Container) error { + // Ensure our directory exists. 
+ if err := os.MkdirAll(p.BasePath, 0755); err != nil { + return err + } -// OnStart implements Profile.OnStart. -func (p *Pprof) OnStart(c *Container) error { + // Find the runtime to invoke. path, err := RuntimePath() if err != nil { return fmt.Errorf("failed to get runtime path: %v", err) @@ -90,58 +75,66 @@ func (p *Pprof) OnStart(c *Container) error { // The root directory of this container's runtime. root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) - // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + + // Format is `runsc --root=rootdir debug --profile-*=file --duration=24h containerID`. args := []string{root, "debug"} - args = append(args, p.makeProfileArgs(c)...) + for _, profileArg := range p.Types { + outputPath := filepath.Join(p.BasePath, fmt.Sprintf("%s.pprof", profileArg)) + args = append(args, fmt.Sprintf("--profile-%s=%s", profileArg, outputPath)) + } + args = append(args, fmt.Sprintf("--duration=%s", p.Duration)) // Or until container exits. args = append(args, c.ID()) // Best effort wait until container is running. for now := time.Now(); time.Since(now) < 5*time.Second; { if status, err := c.Status(context.Background()); err != nil { return fmt.Errorf("failed to get status with: %v", err) - } else if status.Running { break } - time.Sleep(500 * time.Millisecond) + time.Sleep(100 * time.Millisecond) } p.cmd = exec.Command(path, args...) + p.cmd.Stderr = os.Stderr // Pass through errors. if err := p.cmd.Start(); err != nil { - return fmt.Errorf("process failed: %v", err) + return fmt.Errorf("start process failed: %v", err) } + return nil } -// Restart implements Profile.Restart. -func (p *Pprof) Restart(c *Container) error { - p.OnCleanUp(c) - return p.OnStart(c) +// killProcess kills the process, if running. +// +// Precondition: mu must be held. 
+func (p *profile) killProcess() error { + if p.cmd != nil && p.cmd.Process != nil { + return p.cmd.Process.Signal(syscall.SIGTERM) + } + return nil } -// OnCleanUp implements Profile.OnCleanup -func (p *Pprof) OnCleanUp(c *Container) error { +// waitProcess waits for the process, if running. +// +// Precondition: mu must be held. +func (p *profile) waitProcess() error { defer func() { p.cmd = nil }() - if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { - return p.cmd.Process.Kill() + if p.cmd != nil { + return p.cmd.Wait() } return nil } -// makeProfileArgs turns Pprof fields into runsc debug flags. -func (p *Pprof) makeProfileArgs(c *Container) []string { - var ret []string - if p.BlockProfile { - ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) - } - if p.CPUProfile { - ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) - } - if p.HeapProfile { - ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) - } - if p.MutexProfile { - ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) +// Start is called when profiling is started. +func (p *profile) Start(c *Container) error { + return p.createProcess(c) +} + +// Stop is called when profiling is started. +func (p *profile) Stop(c *Container) error { + killErr := p.killProcess() + waitErr := p.waitProcess() + if waitErr != nil && killErr != nil { + return killErr } - ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) - return ret + return waitErr // Ignore okay wait, err kill. 
} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go index 8c4ffe483..4fe9ce15c 100644 --- a/pkg/test/dockerutil/profile_test.go +++ b/pkg/test/dockerutil/profile_test.go @@ -17,6 +17,7 @@ package dockerutil import ( "context" "fmt" + "io/ioutil" "os" "path/filepath" "testing" @@ -25,52 +26,60 @@ import ( type testCase struct { name string - pprof Pprof + profile profile expectedFiles []string } -func TestPprof(t *testing.T) { +func TestProfile(t *testing.T) { // Basepath and expected file names for each type of profile. - basePath := "/tmp/test/profile" + tmpDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("unable to create temporary directory: %v", err) + } + defer os.RemoveAll(tmpDir) + + // All expected names. + basePath := tmpDir block := "block.pprof" cpu := "cpu.pprof" - goprofle := "go.pprof" heap := "heap.pprof" mutex := "mutex.pprof" testCases := []testCase{ { - name: "Cpu", - pprof: Pprof{ - BasePath: basePath, - CPUProfile: true, - Duration: 2 * time.Second, + name: "One", + profile: profile{ + BasePath: basePath, + Types: []string{"cpu"}, + Duration: 2 * time.Second, }, expectedFiles: []string{cpu}, }, { name: "All", - pprof: Pprof{ - BasePath: basePath, - BlockProfile: true, - CPUProfile: true, - HeapProfile: true, - MutexProfile: true, - Duration: 2 * time.Second, + profile: profile{ + BasePath: basePath, + Types: []string{"block", "cpu", "heap", "mutex"}, + Duration: 2 * time.Second, }, - expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + expectedFiles: []string{block, cpu, heap, mutex}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. - tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) - c.AddProfile(&tc.pprof) + localProfile := tc.profile // Copy it. 
+ localProfile.BasePath = filepath.Join(localProfile.BasePath, tc.name) + + // Set directly on the container, to avoid flags. + c.profile = &localProfile func() { defer c.CleanUp(ctx) + // Start a container. if err := c.Spawn(ctx, RunOpts{ Image: "basic/alpine", @@ -83,24 +92,24 @@ func TestPprof(t *testing.T) { } // End early if the expected files exist and have data. - for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { - if err := checkFiles(tc); err == nil { + for start := time.Now(); time.Since(start) < localProfile.Duration; time.Sleep(100 * time.Millisecond) { + if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err == nil { break } } }() // Check all expected files exist and have data. - if err := checkFiles(tc); err != nil { + if err := checkFiles(localProfile.BasePath, tc.expectedFiles); err != nil { t.Fatalf(err.Error()) } }) } } -func checkFiles(tc testCase) error { - for _, file := range tc.expectedFiles { - stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) +func checkFiles(basePath string, expectedFiles []string) error { + for _, file := range expectedFiles { + stat, err := os.Stat(filepath.Join(basePath, file)) if err != nil { return fmt.Errorf("stat failed with: %v", err) } else if stat.Size() < 1 { diff --git a/pkg/urpc/urpc.go b/pkg/urpc/urpc.go index 13b2ea314..dfd23032c 100644 --- a/pkg/urpc/urpc.go +++ b/pkg/urpc/urpc.go @@ -283,12 +283,10 @@ func (s *Server) handleOne(client *unet.Socket) error { // Client is dead. return err } + if s.afterRPCCallback != nil { + defer s.afterRPCCallback() + } - defer func() { - if s.afterRPCCallback != nil { - s.afterRPCCallback() - } - }() // Explicitly close all these files after the call. 
// // This is also explicitly a reference to the files after the call, diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 9b1e7a085..79db8895b 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -167,7 +167,7 @@ func (rw *IOReadWriter) Read(dst []byte) (int, error) { return n, err } -// Writer implements io.Writer.Write. +// Write implements io.Writer.Write. func (rw *IOReadWriter) Write(src []byte) (int, error) { n, err := rw.IO.CopyOut(rw.Ctx, rw.Addr, src, rw.Opts) end, ok := rw.Addr.AddLength(uint64(n)) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 8c73dc5dc..67307ab3c 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fd", + "//pkg/flipcall", "//pkg/fspath", "//pkg/log", "//pkg/memutil", diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 7076ae2e2..a3a76b609 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -53,7 +53,7 @@ type compatEmitter struct { func newCompatEmitter(logFD int) (*compatEmitter, error) { nameMap, ok := getSyscallNameMap() if !ok { - return nil, fmt.Errorf("Linux syscall table not found") + return nil, fmt.Errorf("syscall table not found") } c := &compatEmitter{ diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 865126ac5..9008e1282 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -104,13 +104,11 @@ const ( // Profiling related commands (see pprof.go for more details). 
const ( - StartCPUProfile = "Profile.StartCPUProfile" - StopCPUProfile = "Profile.StopCPUProfile" - HeapProfile = "Profile.HeapProfile" - BlockProfile = "Profile.BlockProfile" - MutexProfile = "Profile.MutexProfile" - StartTrace = "Profile.StartTrace" - StopTrace = "Profile.StopTrace" + CPUProfile = "Profile.CPU" + HeapProfile = "Profile.Heap" + BlockProfile = "Profile.Block" + MutexProfile = "Profile.Mutex" + Trace = "Profile.Trace" ) // Logging related commands (see logging.go for more details). @@ -132,8 +130,13 @@ type controller struct { // manager holds the containerManager methods. manager *containerManager - // pprop holds the profile instance if enabled. It may be nil. + // pprof holds the profile instance if enabled. It may be nil. pprof *control.Profile + + // stopProfiling has the callback to stop profiling calls. As + // this may be executed only once at most, it will be set to nil + // after it is executed for the first time. + stopProfiling func() } // newController creates a new controller. The caller must call @@ -164,7 +167,7 @@ func newController(fd int, l *Loader) (*controller, error) { ctrl.srv.Register(&control.Logging{}) if l.root.conf.ProfileEnable { - ctrl.pprof = &control.Profile{Kernel: l.k} + ctrl.pprof, ctrl.stopProfiling = control.NewProfile(l.k) ctrl.srv.Register(ctrl.pprof) } @@ -172,10 +175,9 @@ func newController(fd int, l *Loader) (*controller, error) { } func (c *controller) stop() { - if c.pprof != nil { - // These are noop if there is nothing being profiled. 
- _ = c.pprof.StopCPUProfile(nil, nil) - _ = c.pprof.StopTrace(nil, nil) + if c.stopProfiling != nil { + c.stopProfiling() + c.stopProfiling = nil } } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index a7c4ebb0c..eacd73531 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -343,6 +343,21 @@ func hostInetFilters() seccomp.SyscallRules { }, { seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_PKTINFO), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVERR), + }, + { + seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), seccomp.EqualTo(syscall.IPV6_TCLASS), }, @@ -354,10 +369,20 @@ func hostInetFilters() seccomp.SyscallRules { { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVERR), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), seccomp.EqualTo(syscall.IPV6_V6ONLY), }, { seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR), + }, + { + seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_SOCKET), seccomp.EqualTo(syscall.SO_ERROR), }, @@ -393,6 +418,11 @@ func hostInetFilters() seccomp.SyscallRules { }, { seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_TIMESTAMP), + }, + { + seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_TCP), seccomp.EqualTo(syscall.TCP_NODELAY), }, @@ -401,6 +431,11 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.EqualTo(syscall.SOL_TCP), seccomp.EqualTo(syscall.TCP_INFO), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(linux.TCP_INQ), + }, }, syscall.SYS_IOCTL: []seccomp.Rule{ { @@ -421,29 +456,29 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_SETSOCKOPT: []seccomp.Rule{ { seccomp.MatchAny{}, - 
seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.EqualTo(syscall.SO_RCVBUF), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_RCVBUF), + seccomp.EqualTo(syscall.SO_REUSEADDR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_REUSEADDR), + seccomp.EqualTo(syscall.SO_TIMESTAMP), seccomp.MatchAny{}, seccomp.EqualTo(4), }, @@ -456,6 +491,13 @@ func hostInetFilters() seccomp.SyscallRules { }, { seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(linux.TCP_INQ), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IP), seccomp.EqualTo(syscall.IP_TOS), seccomp.MatchAny{}, @@ -470,6 +512,27 @@ func hostInetFilters() seccomp.SyscallRules { }, { seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_PKTINFO), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVERR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, seccomp.EqualTo(syscall.SOL_IPV6), seccomp.EqualTo(syscall.IPV6_TCLASS), seccomp.MatchAny{}, @@ -482,6 +545,27 @@ func hostInetFilters() seccomp.SyscallRules { seccomp.MatchAny{}, seccomp.EqualTo(4), }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, + 
seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVERR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, + { + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), + }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 3df013d34..f41d6c665 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -294,7 +294,7 @@ func New(args Args) (*Loader, error) { if args.TotalMem > 0 { // Adjust the total memory returned by the Sentry so that applications that // use /proc/meminfo can make allocations based on this limit. - usage.MinimumTotalMemoryBytes = args.TotalMem + usage.MaximumTotalMemoryBytes = args.TotalMem log.Infof("Setting total memory to %.2f GB", float64(args.TotalMem)/(1<<30)) } @@ -598,7 +598,6 @@ func (l *Loader) run() error { if err != nil { return err } - } ep.tg = l.k.GlobalInit() @@ -1045,9 +1044,10 @@ func (l *Loader) WaitExit() kernel.ExitStatus { // Wait for container. l.k.WaitExited() - // Cleanup + // Stop the control server. l.ctrl.stop() + // Check all references. refs.OnExit() return l.k.GlobalInit().ExitStatus() diff --git a/runsc/cli/main.go b/runsc/cli/main.go index bca015db5..6c3bf4d21 100644 --- a/runsc/cli/main.go +++ b/runsc/cli/main.go @@ -22,6 +22,7 @@ import ( "io/ioutil" "os" "os/signal" + "runtime" "syscall" "time" @@ -82,6 +83,7 @@ func Main(version string) { subcommands.Register(new(cmd.Spec), "") subcommands.Register(new(cmd.State), "") subcommands.Register(new(cmd.Start), "") + subcommands.Register(new(cmd.Symbolize), "") subcommands.Register(new(cmd.Wait), "") // Register internal commands with the internal group name. 
This causes @@ -207,6 +209,8 @@ func Main(version string) { log.Infof("***************************") log.Infof("Args: %s", os.Args) log.Infof("Version %s", version) + log.Infof("GOOS: %s", runtime.GOOS) + log.Infof("GOARCH: %s", runtime.GOARCH) log.Infof("PID: %d", os.Getpid()) log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid()) log.Infof("Configuration:") diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 2556f6d9e..19520d7ab 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -32,6 +32,7 @@ go_library( "start.go", "state.go", "statefile.go", + "symbolize.go", "syscalls.go", "wait.go", ], @@ -39,6 +40,7 @@ go_library( "//runsc:__subpackages__", ], deps = [ + "//pkg/coverage", "//pkg/log", "//pkg/p9", "//pkg/sentry/control", diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index c0bc8f064..124198239 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -75,7 +75,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) - cont, err := container.LoadAndCheck(conf.RootDir, id) + cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 609e8231c..843dce01d 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -19,6 +19,7 @@ import ( "os" "strconv" "strings" + "sync" "syscall" "time" @@ -70,10 +71,10 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.StringVar(&d.profileCPU, "profile-cpu", "", "writes CPU profile to the given file.") f.StringVar(&d.profileBlock, "profile-block", "", "writes block profile to the given file.") f.StringVar(&d.profileMutex, "profile-mutex", "", "writes mutex profile to the given file.") - f.DurationVar(&d.duration, "duration", time.Second, "amount of time to wait for CPU and trace profiles") + f.DurationVar(&d.duration, "duration", 
time.Second, "amount of time to wait for CPU and trace profiles.") f.StringVar(&d.trace, "trace", "", "writes an execution trace to the given file.") f.IntVar(&d.signal, "signal", -1, "sends signal to the sandbox") - f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all`) + f.StringVar(&d.strace, "strace", "", `A comma separated list of syscalls to trace. "all" enables all traces, "off" disables all.`) f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).") f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.") f.BoolVar(&d.ps, "ps", false, "lists processes") @@ -90,8 +91,10 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) f.Usage() return subcommands.ExitUsageError } + id := f.Arg(0) + var err error - c, err = container.LoadAndCheck(conf.RootDir, f.Arg(0)) + c, err = container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { return Errorf("loading container %q: %v", f.Arg(0), err) } @@ -106,9 +109,10 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return Errorf("listing containers: %v", err) } for _, id := range ids { - candidate, err := container.LoadAndCheck(conf.RootDir, id) + candidate, err := container.Load(conf.RootDir, id, container.LoadOpts{Exact: true, SkipCheck: true}) if err != nil { - return Errorf("loading container %q: %v", id, err) + log.Warningf("Skipping container %q: %v", id, err) + continue } if candidate.SandboxPid() == d.pid { c = candidate @@ -120,11 +124,12 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } } - if c.Sandbox == nil || !c.Sandbox.IsRunning() { + if !c.IsSandboxRunning() { return Errorf("container sandbox is not running") } log.Infof("Found sandbox %q, PID: %d", c.Sandbox.ID, c.Sandbox.Pid) + // 
Perform synchronous actions. if d.signal > 0 { log.Infof("Sending signal %d to process: %d", d.signal, c.Sandbox.Pid) if err := syscall.Kill(c.Sandbox.Pid, syscall.Signal(d.signal)); err != nil { @@ -140,80 +145,15 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) log.Infof(" *** Stack dump ***\n%s", stacks) } if d.profileHeap != "" { - f, err := os.Create(d.profileHeap) + f, err := os.OpenFile(d.profileHeap, os.O_CREATE|os.O_TRUNC, 0644) if err != nil { - return Errorf(err.Error()) + return Errorf("error opening heap profile output: %v", err) } defer f.Close() - if err := c.Sandbox.HeapProfile(f); err != nil { - return Errorf(err.Error()) - } - log.Infof("Heap profile written to %q", d.profileHeap) - } - if d.profileBlock != "" { - f, err := os.Create(d.profileBlock) - if err != nil { - return Errorf(err.Error()) - } - defer f.Close() - - if err := c.Sandbox.BlockProfile(f); err != nil { - return Errorf(err.Error()) - } - log.Infof("Block profile written to %q", d.profileBlock) - } - if d.profileMutex != "" { - f, err := os.Create(d.profileMutex) - if err != nil { - return Errorf(err.Error()) - } - defer f.Close() - - if err := c.Sandbox.MutexProfile(f); err != nil { - return Errorf(err.Error()) + return Errorf("error collecting heap profile: %v", err) } - log.Infof("Mutex profile written to %q", d.profileMutex) } - - delay := false - if d.profileCPU != "" { - delay = true - f, err := os.Create(d.profileCPU) - if err != nil { - return Errorf(err.Error()) - } - defer func() { - f.Close() - if err := c.Sandbox.StopCPUProfile(); err != nil { - Fatalf(err.Error()) - } - log.Infof("CPU profile written to %q", d.profileCPU) - }() - if err := c.Sandbox.StartCPUProfile(f); err != nil { - return Errorf(err.Error()) - } - log.Infof("CPU profile started for %v, writing to %q", d.duration, d.profileCPU) - } - if d.trace != "" { - delay = true - f, err := os.Create(d.trace) - if err != nil { - return Errorf(err.Error()) - } - defer func() { - 
f.Close() - if err := c.Sandbox.StopTrace(); err != nil { - Fatalf(err.Error()) - } - log.Infof("Trace written to %q", d.trace) - }() - if err := c.Sandbox.StartTrace(f); err != nil { - return Errorf(err.Error()) - } - log.Infof("Tracing started for %v, writing to %q", d.duration, d.trace) - } - if d.strace != "" || len(d.logLevel) != 0 || len(d.logPackets) != 0 { args := control.LoggingArgs{} switch strings.ToLower(d.strace) { @@ -282,8 +222,98 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) log.Infof(o) } - if delay { - time.Sleep(d.duration) + // Open profiling files. + var ( + cpuFile *os.File + traceFile *os.File + blockFile *os.File + mutexFile *os.File + ) + if d.profileCPU != "" { + f, err := os.OpenFile(d.profileCPU, os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return Errorf("error opening cpu profile output: %v", err) + } + defer f.Close() + cpuFile = f + } + if d.trace != "" { + f, err := os.OpenFile(d.trace, os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return Errorf("error opening trace profile output: %v", err) + } + traceFile = f + } + if d.profileBlock != "" { + f, err := os.OpenFile(d.profileBlock, os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return Errorf("error opening blocking profile output: %v", err) + } + defer f.Close() + blockFile = f + } + if d.profileMutex != "" { + f, err := os.OpenFile(d.profileMutex, os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return Errorf("error opening mutex profile output: %v", err) + } + defer f.Close() + mutexFile = f + } + + // Collect profiles. 
+ var ( + wg sync.WaitGroup + cpuErr error + traceErr error + blockErr error + mutexErr error + ) + if cpuFile != nil { + wg.Add(1) + go func() { + defer wg.Done() + cpuErr = c.Sandbox.CPUProfile(cpuFile, d.duration) + }() + } + if traceFile != nil { + wg.Add(1) + go func() { + defer wg.Done() + traceErr = c.Sandbox.Trace(traceFile, d.duration) + }() + } + if blockFile != nil { + wg.Add(1) + go func() { + defer wg.Done() + blockErr = c.Sandbox.BlockProfile(blockFile, d.duration) + }() + } + if mutexFile != nil { + wg.Add(1) + go func() { + defer wg.Done() + mutexErr = c.Sandbox.MutexProfile(mutexFile, d.duration) + }() + } + + wg.Wait() + errorCount := 0 + if cpuErr != nil { + log.Infof("error collecting cpu profile: %v", cpuErr) + } + if traceErr != nil { + log.Infof("error collecting trace profile: %v", traceErr) + } + if blockErr != nil { + log.Infof("error collecting block profile: %v", blockErr) + } + if mutexErr != nil { + log.Infof("error collecting mutex profile: %v", mutexErr) + } + if errorCount > 0 { + return subcommands.ExitFailure } return subcommands.ExitSuccess diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go index a25637265..a750be131 100644 --- a/runsc/cmd/delete.go +++ b/runsc/cmd/delete.go @@ -68,7 +68,7 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} func (d *Delete) execute(ids []string, conf *config.Config) error { for _, id := range ids { - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { if os.IsNotExist(err) && d.force { log.Warningf("couldn't find container %q: %v", id, err) diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 3836b7b4e..75b0aac8d 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -74,7 +74,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa id := f.Arg(0) conf := args[0].(*config.Config) - c, err := 
container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index eafd6285c..8558d34ae 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -112,7 +112,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } waitStatus := args[1].(*syscall.WaitStatus) - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index fe69e2a08..aecf0b7ab 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -69,7 +69,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("it is invalid to specify both --all and --pid") } - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go index 6907eb16a..9f9a47bd8 100644 --- a/runsc/cmd/list.go +++ b/runsc/cmd/list.go @@ -24,6 +24,7 @@ import ( "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -71,7 +72,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if l.quiet { for _, id := range ids { - fmt.Println(id) + fmt.Println(id.ContainerID) } return subcommands.ExitSuccess } @@ -79,9 +80,10 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Collect the containers. 
var containers []*container.Container for _, id := range ids { - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, id, container.LoadOpts{Exact: true}) if err != nil { - Fatalf("loading container %q: %v", id, err) + log.Warningf("Skipping container %q: %v", id, err) + continue } containers = append(containers, c) } diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go index fe7d4e257..15ef7b577 100644 --- a/runsc/cmd/pause.go +++ b/runsc/cmd/pause.go @@ -55,7 +55,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.LoadAndCheck(conf.RootDir, id) + cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go index 18d7a1436..04e3e0bdd 100644 --- a/runsc/cmd/ps.go +++ b/runsc/cmd/ps.go @@ -60,7 +60,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go index a00928204..856469252 100644 --- a/runsc/cmd/resume.go +++ b/runsc/cmd/resume.go @@ -56,7 +56,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.LoadAndCheck(conf.RootDir, id) + cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index f6499cc44..964a65064 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -55,7 +55,7 @@ func (*Start) 
Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go index d8a70dd7f..1f7913d5a 100644 --- a/runsc/cmd/state.go +++ b/runsc/cmd/state.go @@ -57,7 +57,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/symbolize.go b/runsc/cmd/symbolize.go new file mode 100644 index 000000000..fc0c69358 --- /dev/null +++ b/runsc/cmd/symbolize.go @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "bufio" + "context" + "os" + "strconv" + "strings" + + "github.com/google/subcommands" + "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/runsc/flag" +) + +// Symbolize implements subcommands.Command for the "symbolize" command. +type Symbolize struct { + dumpAll bool +} + +// Name implements subcommands.Command.Name. 
+func (*Symbolize) Name() string { + return "symbolize" +} + +// Synopsis implements subcommands.Command.Synopsis. +func (*Symbolize) Synopsis() string { + return "Convert synthetic instruction pointers from kcov into positions in the runsc source code. Only used when Go coverage is enabled." +} + +// Usage implements subcommands.Command.Usage. +func (*Symbolize) Usage() string { + return `symbolize - converts synthetic instruction pointers into positions in the runsc source code. + +This command takes instruction pointers from stdin and converts them into their +corresponding file names and line/column numbers in the runsc source code. The +inputs are not interpreted as actual addresses, but as synthetic values that are +exposed through /sys/kernel/debug/kcov. One can extract coverage information +from kcov and translate those values into locations in the source code by +running symbolize on the same runsc binary. +` +} + +// SetFlags implements subcommands.Command.SetFlags. +func (c *Symbolize) SetFlags(f *flag.FlagSet) { + f.BoolVar(&c.dumpAll, "all", false, "dump information on all coverage blocks along with their synthetic PCs") +} + +// Execute implements subcommands.Command.Execute. +func (c *Symbolize) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if f.NArg() != 0 { + f.Usage() + return subcommands.ExitUsageError + } + if !coverage.KcovAvailable() { + return Errorf("symbolize can only be used when coverage is available.") + } + coverage.InitCoverageData() + + if c.dumpAll { + coverage.WriteAllBlocks(os.Stdout) + return subcommands.ExitSuccess + } + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + // Input is always base 16, but may or may not have a leading "0x". 
+ str := strings.TrimPrefix(scanner.Text(), "0x") + pc, err := strconv.ParseUint(str, 16 /* base */, 64 /* bitSize */) + if err != nil { + return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err) + } + if err := coverage.Symbolize(os.Stdout, pc); err != nil { + return Errorf("Failed to symbolize \"%s\": %v", scanner.Text(), err) + } + } + return subcommands.ExitSuccess +} diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index c1d6aeae2..5d55422c7 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -72,7 +72,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.LoadAndCheck(conf.RootDir, id) + c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/config/config.go b/runsc/config/config.go index b02d8e2e1..e9fd7708f 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -131,7 +131,7 @@ type Config struct { NumNetworkChannels int `flag:"num-network-channels"` // Rootless allows the sandbox to be started with a user that is not root. - // Defense is depth measures are weaker with rootless. Specifically, the + // Defense in depth measures are weaker in rootless mode. Specifically, the // sandbox and Gofer process run as root inside a user namespace with root // mapped to the caller's user. Rootless bool `flag:"rootless"` diff --git a/runsc/container/container.go b/runsc/container/container.go index 418a27beb..8b78660f7 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -128,125 +128,6 @@ type Container struct { goferIsChild bool } -// loadSandbox loads all containers that belong to the sandbox with the given -// ID. -func loadSandbox(rootDir, id string) ([]*Container, error) { - cids, err := List(rootDir) - if err != nil { - return nil, err - } - - // Load the container metadata. 
- var containers []*Container - for _, cid := range cids { - container, err := Load(rootDir, cid) - if err != nil { - // Container file may not exist if it raced with creation/deletion or - // directory was left behind. Load provides a snapshot in time, so it's - // fine to skip it. - if os.IsNotExist(err) { - continue - } - return nil, fmt.Errorf("loading container %q: %v", id, err) - } - if container.Sandbox.ID == id { - containers = append(containers, container) - } - } - return containers, nil -} - -// Load loads a container with the given id from a metadata file. partialID may -// be an abbreviation of the full container id, in which case Load loads the -// container to which id unambiguously refers to. Returns ErrNotExist if -// container doesn't exist. -func Load(rootDir, partialID string) (*Container, error) { - log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID) - if err := validateID(partialID); err != nil { - return nil, fmt.Errorf("invalid container id: %v", err) - } - - id, err := findContainerID(rootDir, partialID) - if err != nil { - // Preserve error so that callers can distinguish 'not found' errors. - return nil, err - } - - state := StateFile{ - RootDir: rootDir, - ID: id, - } - defer state.close() - - c := &Container{} - if err := state.load(c); err != nil { - if os.IsNotExist(err) { - // Preserve error so that callers can distinguish 'not found' errors. - return nil, err - } - return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) - } - return c, nil -} - -// LoadAndCheck is similar to Load(), but also checks if the container is still -// running to get an error earlier to the caller. -func LoadAndCheck(rootDir, partialID string) (*Container, error) { - c, err := Load(rootDir, partialID) - if err != nil { - // Preserve error so that callers can distinguish 'not found' errors. 
- return nil, err - } - - // If the status is "Running" or "Created", check that the sandbox/container - // is still running, setting it to Stopped if not. - // - // This is inherently racy. - switch c.Status { - case Created: - if !c.isSandboxRunning() { - // Sandbox no longer exists, so this container definitely does not exist. - c.changeStatus(Stopped) - } - case Running: - if err := c.SignalContainer(syscall.Signal(0), false); err != nil { - c.changeStatus(Stopped) - } - } - - return c, nil -} - -func findContainerID(rootDir, partialID string) (string, error) { - // Check whether the id fully specifies an existing container. - stateFile := buildStatePath(rootDir, partialID) - if _, err := os.Stat(stateFile); err == nil { - return partialID, nil - } - - // Now see whether id could be an abbreviation of exactly 1 of the - // container ids. If id is ambiguous (it could match more than 1 - // container), it is an error. - ids, err := List(rootDir) - if err != nil { - return "", err - } - rv := "" - for _, id := range ids { - if strings.HasPrefix(id, partialID) { - if rv != "" { - return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id) - } - rv = id - } - } - if rv == "" { - return "", os.ErrNotExist - } - log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv) - return rv, nil -} - // Args is used to configure a new container. type Args struct { // ID is the container unique identifier. 
@@ -291,6 +172,15 @@ func New(conf *config.Config, args Args) (*Container, error) { return nil, fmt.Errorf("creating container root directory %q: %v", conf.RootDir, err) } + sandboxID := args.ID + if !isRoot(args.Spec) { + var ok bool + sandboxID, ok = specutils.SandboxID(args.Spec) + if !ok { + return nil, fmt.Errorf("no sandbox ID found when creating container") + } + } + c := &Container{ ID: args.ID, Spec: args.Spec, @@ -301,7 +191,10 @@ func New(conf *config.Config, args Args) (*Container, error) { Owner: os.Getenv("USER"), Saver: StateFile{ RootDir: conf.RootDir, - ID: args.ID, + ID: FullID{ + SandboxID: sandboxID, + ContainerID: args.ID, + }, }, } // The Cleanup object cleans up partially created containers when an error @@ -316,10 +209,17 @@ func New(conf *config.Config, args Args) (*Container, error) { } defer c.Saver.unlock() - // If the metadata annotations indicate that this container should be - // started in an existing sandbox, we must do so. The metadata will - // indicate the ID of the sandbox, which is the same as the ID of the - // init container in the sandbox. + // If the metadata annotations indicate that this container should be started + // in an existing sandbox, we must do so. These are the possible metadata + // annotation states: + // 1. No annotations: it means that there is a single container and this + // container is obviously the root. Both container and sandbox share the + // ID. + // 2. Container type == sandbox: it means this is the root container + // starting the sandbox. Both container and sandbox share the same ID. + // 3. Container type == container: it means this is a subcontainer of an + // already started sandbox. In this case, container ID is different than + // the sandbox ID. if isRoot(args.Spec) { log.Debugf("Creating new sandbox for container, cid: %s", args.ID) @@ -358,7 +258,7 @@ func New(conf *config.Config, args Args) (*Container, error) { // Start a new sandbox for this container. 
Any errors after this point // must destroy the container. sandArgs := &sandbox.Args{ - ID: args.ID, + ID: sandboxID, Spec: args.Spec, BundleDir: args.BundleDir, ConsoleSocket: args.ConsoleSocket, @@ -379,22 +279,14 @@ func New(conf *config.Config, args Args) (*Container, error) { return nil, err } } else { - // This is sort of confusing. For a sandbox with a root - // container and a child container in it, runsc sees: - // * A container struct whose sandbox ID is equal to the - // container ID. This is the root container that is tied to - // the creation of the sandbox. - // * A container struct whose sandbox ID is equal to the above - // container/sandbox ID, but that has a different container - // ID. This is the child container. - sbid, ok := specutils.SandboxID(args.Spec) - if !ok { - return nil, fmt.Errorf("no sandbox ID found when creating container") - } - log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sbid) + log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sandboxID) // Find the sandbox associated with this ID. - sb, err := LoadAndCheck(conf.RootDir, sbid) + fullID := FullID{ + SandboxID: sandboxID, + ContainerID: sandboxID, + } + sb, err := Load(conf.RootDir, fullID, LoadOpts{Exact: true}) if err != nil { return nil, err } @@ -628,7 +520,7 @@ func (c *Container) Wait() (syscall.WaitStatus, error) { // returns its WaitStatus. func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID) - if !c.isSandboxRunning() { + if !c.IsSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } return c.Sandbox.WaitPID(c.Sandbox.ID, pid) @@ -638,7 +530,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { // its WaitStatus. 
func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID) - if !c.isSandboxRunning() { + if !c.IsSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } return c.Sandbox.WaitPID(c.ID, pid) @@ -658,7 +550,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { if err := c.requireStatus("signal", Running, Stopped); err != nil { return err } - if !c.isSandboxRunning() { + if !c.IsSandboxRunning() { return fmt.Errorf("sandbox is not running") } return c.Sandbox.SignalContainer(c.ID, sig, all) @@ -670,7 +562,7 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { if err := c.requireStatus("signal a process inside", Running); err != nil { return err } - if !c.isSandboxRunning() { + if !c.IsSandboxRunning() { return fmt.Errorf("sandbox is not running") } return c.Sandbox.SignalProcess(c.ID, int32(pid), sig, false) @@ -889,7 +781,7 @@ func (c *Container) waitForStopped() error { defer cancel() b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) op := func() error { - if c.isSandboxRunning() { + if c.IsSandboxRunning() { if err := c.SignalContainer(syscall.Signal(0), false); err == nil { return fmt.Errorf("container is still running") } @@ -1091,7 +983,7 @@ func (c *Container) changeStatus(s Status) { c.Status = s } -func (c *Container) isSandboxRunning() bool { +func (c *Container) IsSandboxRunning() bool { return c.Sandbox != nil && c.Sandbox.IsRunning() } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index fa99e403a..a92ae046d 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -364,7 +364,7 @@ func TestLifecycle(t *testing.T) { defer c.Destroy() // Load the container from disk and check the status. 
- c, err = LoadAndCheck(rootDir, args.ID) + c, err = Load(rootDir, FullID{ContainerID: args.ID}, LoadOpts{}) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -377,7 +377,11 @@ func TestLifecycle(t *testing.T) { if err != nil { t.Fatalf("error listing containers: %v", err) } - if got, want := ids, []string{args.ID}; !reflect.DeepEqual(got, want) { + fullID := FullID{ + SandboxID: args.ID, + ContainerID: args.ID, + } + if got, want := ids, []FullID{fullID}; !reflect.DeepEqual(got, want) { t.Errorf("container list got %v, want %v", got, want) } @@ -387,7 +391,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = LoadAndCheck(rootDir, args.ID) + c, err = Load(rootDir, fullID, LoadOpts{Exact: true}) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -428,7 +432,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = LoadAndCheck(rootDir, args.ID) + c, err = Load(rootDir, fullID, LoadOpts{Exact: true}) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -451,7 +455,7 @@ func TestLifecycle(t *testing.T) { } // Loading the container by id should fail. 
- if _, err = LoadAndCheck(rootDir, args.ID); err == nil { + if _, err = Load(rootDir, fullID, LoadOpts{Exact: true}); err == nil { t.Errorf("expected loading destroyed container to fail, but it did not") } }) @@ -1738,7 +1742,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { cids[2]: cids[2], } for shortid, longid := range unambiguous { - if _, err := LoadAndCheck(rootDir, shortid); err != nil { + if _, err := Load(rootDir, FullID{ContainerID: shortid}, LoadOpts{}); err != nil { t.Errorf("%q should resolve to %q: %v", shortid, longid, err) } } @@ -1749,7 +1753,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { "ba", } for _, shortid := range ambiguous { - if s, err := LoadAndCheck(rootDir, shortid); err == nil { + if s, err := Load(rootDir, FullID{ContainerID: shortid}, LoadOpts{}); err == nil { t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID) } } @@ -2007,7 +2011,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { // Container is not thread safe, so load another instance to run in // concurrently. - startCont, err := LoadAndCheck(rootDir, args.ID) + startCont, err := Load(rootDir, FullID{ContainerID: args.ID}, LoadOpts{}) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -2332,6 +2336,42 @@ func TestTTYField(t *testing.T) { } } +// Test that container can run even when there are corrupt state files in the +// root directiry. +func TestCreateWithCorruptedStateFile(t *testing.T) { + conf := testutil.TestConfig(t) + spec := testutil.NewSpecWithArgs("/bin/true") + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer cleanup() + + // Create corrupted state file. 
+ corruptID := testutil.RandomContainerID() + corruptState := buildPath(conf.RootDir, FullID{SandboxID: corruptID, ContainerID: corruptID}, stateFileExtension) + if err := ioutil.WriteFile(corruptState, []byte("this{file(is;not[valid.json"), 0777); err != nil { + t.Fatalf("createCorruptStateFile(): %v", err) + } + defer os.Remove(corruptState) + + if _, err := Load(conf.RootDir, FullID{ContainerID: corruptID}, LoadOpts{SkipCheck: true}); err == nil { + t.Fatalf("loading corrupted state file should have failed") + } + + args := Args{ + ID: testutil.RandomContainerID(), + Spec: spec, + BundleDir: bundleDir, + Attached: true, + } + if ws, err := Run(conf, args); err != nil { + t.Errorf("running container: %v", err) + } else if !ws.Exited() || ws.ExitStatus() != 0 { + t.Errorf("container failed, waitStatus: %v", ws) + } +} + func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) { args := &control.ExecArgs{ Filename: name, diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 45d4e6e6e..29db1b7e8 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -730,7 +730,7 @@ func TestMultiContainerKillAll(t *testing.T) { // processes still running inside. containers[1].SignalContainer(syscall.SIGKILL, false) op := func() error { - c, err := LoadAndCheck(conf.RootDir, ids[1]) + c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{}) if err != nil { return err } @@ -744,7 +744,7 @@ func TestMultiContainerKillAll(t *testing.T) { } } - c, err := LoadAndCheck(conf.RootDir, ids[1]) + c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{}) if err != nil { t.Fatalf("failed to load child container %q: %v", c.ID, err) } @@ -867,7 +867,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) { // Container is not thread safe, so load another instance to run in // concurrently. 
- startCont, err := LoadAndCheck(rootDir, ids[i]) + startCont, err := Load(rootDir, FullID{ContainerID: ids[i]}, LoadOpts{}) if err != nil { t.Fatalf("error loading container: %v", err) } diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index 17a251530..dfbf1f2d3 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -20,58 +20,228 @@ import ( "io/ioutil" "os" "path/filepath" + "regexp" + "strings" + "syscall" "github.com/gofrs/flock" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" ) -const stateFileExtension = ".state" +const stateFileExtension = "state" -// StateFile handles load from/save to container state safely from multiple -// processes. It uses a lock file to provide synchronization between operations. +// LoadOpts provides options for Load()ing a container. +type LoadOpts struct { + // Exact tells whether the search should be exact. See Load() for more. + Exact bool + + // SkipCheck tells Load() to skip checking if container is runnning. + SkipCheck bool +} + +// Load loads a container with the given id from a metadata file. "id" may +// be an abbreviation of the full container id in case LoadOpts.Exact if not +// set. It also checks if the container is still running, in order to return +// an error to the caller earlier. This check is skipped if LoadOpts.SkipCheck +// is set. // -// The lock file is located at: "${s.RootDir}/${s.ID}.lock". -// The state file is located at: "${s.RootDir}/${s.ID}.state". -type StateFile struct { - // RootDir is the directory containing the container metadata file. - RootDir string `json:"rootDir"` +// Returns ErrNotExist if no container is found. Returns error in case more than +// one containers matching the ID prefix is found. 
+func Load(rootDir string, id FullID, opts LoadOpts) (*Container, error) { + //log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID) + if !opts.Exact { + var err error + id, err = findContainerID(rootDir, id.ContainerID) + if err != nil { + // Preserve error so that callers can distinguish 'not found' errors. + return nil, err + } + } - // ID is the container ID. - ID string `json:"id"` + if err := id.validate(); err != nil { + return nil, fmt.Errorf("invalid container id: %v", err) + } + state := StateFile{ + RootDir: rootDir, + ID: id, + } + defer state.close() - // - // Fields below this line are not saved in the state file and will not - // be preserved across commands. - // + c := &Container{} + if err := state.load(c); err != nil { + if os.IsNotExist(err) { + // Preserve error so that callers can distinguish 'not found' errors. + return nil, err + } + return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) + } - once sync.Once - flock *flock.Flock + if !opts.SkipCheck { + // If the status is "Running" or "Created", check that the sandbox/container + // is still running, setting it to Stopped if not. + // + // This is inherently racy. + switch c.Status { + case Created: + if !c.IsSandboxRunning() { + // Sandbox no longer exists, so this container definitely does not exist. + c.changeStatus(Stopped) + } + case Running: + if err := c.SignalContainer(syscall.Signal(0), false); err != nil { + c.changeStatus(Stopped) + } + } + } + + return c, nil } // List returns all container ids in the given root directory. -func List(rootDir string) ([]string, error) { +func List(rootDir string) ([]FullID, error) { log.Debugf("List containers %q", rootDir) - list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension)) + return listMatch(rootDir, FullID{}) +} + +// listMatch returns all container ids that match the provided id. 
+func listMatch(rootDir string, id FullID) ([]FullID, error) { + id.SandboxID += "*" + id.ContainerID += "*" + pattern := buildPath(rootDir, id, stateFileExtension) + list, err := filepath.Glob(pattern) if err != nil { return nil, err } - var out []string + var out []FullID for _, path := range list { - // Filter out files that do no belong to a container. - fileName := filepath.Base(path) - if len(fileName) < len(stateFileExtension) { - panic(fmt.Sprintf("invalid file match %q", path)) - } - // Remove the extension. - cid := fileName[:len(fileName)-len(stateFileExtension)] - if validateID(cid) == nil { - out = append(out, cid) + id, err := parseFileName(filepath.Base(path)) + if err == nil { + out = append(out, id) } } return out, nil } +// loadSandbox loads all containers that belong to the sandbox with the given +// ID. +func loadSandbox(rootDir, id string) ([]*Container, error) { + cids, err := listMatch(rootDir, FullID{SandboxID: id}) + if err != nil { + return nil, err + } + + // Load the container metadata. + var containers []*Container + for _, cid := range cids { + container, err := Load(rootDir, cid, LoadOpts{Exact: true, SkipCheck: true}) + if err != nil { + // Container file may not exist if it raced with creation/deletion or + // directory was left behind. Load provides a snapshot in time, so it's + // fine to skip it. + if os.IsNotExist(err) { + continue + } + return nil, fmt.Errorf("loading sandbox %q, failed to load container %q: %v", id, cid, err) + } + containers = append(containers, container) + } + return containers, nil +} + +func findContainerID(rootDir, partialID string) (FullID, error) { + // Check whether the id fully specifies an existing container. 
+ pattern := buildPath(rootDir, FullID{SandboxID: "*", ContainerID: partialID + "*"}, stateFileExtension) + list, err := filepath.Glob(pattern) + if err != nil { + return FullID{}, err + } + switch len(list) { + case 0: + return FullID{}, os.ErrNotExist + case 1: + return parseFileName(filepath.Base(list[0])) + } + + // Now see whether id could be an abbreviation of exactly 1 of the + // container ids. If id is ambiguous (it could match more than 1 + // container), it is an error. + ids, err := List(rootDir) + if err != nil { + return FullID{}, err + } + var rv *FullID + for _, id := range ids { + if strings.HasPrefix(id.ContainerID, partialID) { + if rv != nil { + return FullID{}, fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id) + } + rv = &id + } + } + if rv == nil { + return FullID{}, os.ErrNotExist + } + log.Debugf("abbreviated id %q resolves to full id %v", partialID, *rv) + return *rv, nil +} + +func parseFileName(name string) (FullID, error) { + re := regexp.MustCompile(`([\w+-\.]+)_sandbox:([\w+-\.]+)\.` + stateFileExtension) + groups := re.FindStringSubmatch(name) + if len(groups) != 3 { + return FullID{}, fmt.Errorf("invalid state file name format: %q", name) + } + id := FullID{ + SandboxID: groups[2], + ContainerID: groups[1], + } + if err := id.validate(); err != nil { + return FullID{}, fmt.Errorf("invalid state file name %q: %w", name, err) + } + return id, nil +} + +// FullID combines sandbox and container ID to identify a container. Sandbox ID +// is used to allow all containers for a given sandbox to be loaded by matching +// sandbox ID in the file name. 
+type FullID struct { + SandboxID string `json:"sandboxId"` + ContainerID string `json:"containerId"` +} + +func (f *FullID) String() string { + return f.SandboxID + "/" + f.ContainerID +} + +func (f *FullID) validate() error { + if err := validateID(f.SandboxID); err != nil { + return err + } + return validateID(f.ContainerID) +} + +// StateFile handles load from/save to container state safely from multiple +// processes. It uses a lock file to provide synchronization between operations. +// +// The lock file is located at: "${s.RootDir}/${containerd-id}_sand:{sandbox-id}.lock". +// The state file is located at: "${s.RootDir}/${containerd-id}_sand:{sandbox-id}.state". +type StateFile struct { + // RootDir is the directory containing the container metadata file. + RootDir string `json:"rootDir"` + + // ID is the sandbox+container ID. + ID FullID `json:"id"` + + // + // Fields below this line are not saved in the state file and will not + // be preserved across commands. + // + + once sync.Once + flock *flock.Flock +} + // lock globally locks all locking operations for the container. func (s *StateFile) lock() error { s.once.Do(func() { @@ -157,18 +327,20 @@ func (s *StateFile) close() error { return s.flock.Close() } -func buildStatePath(rootDir, id string) string { - return filepath.Join(rootDir, id+stateFileExtension) +func buildPath(rootDir string, id FullID, extension string) string { + // Note: "_" and ":" are not valid in IDs. + name := fmt.Sprintf("%s_sandbox:%s.%s", id.ContainerID, id.SandboxID, extension) + return filepath.Join(rootDir, name) } // statePath is the full path to the state file. func (s *StateFile) statePath() string { - return buildStatePath(s.RootDir, s.ID) + return buildPath(s.RootDir, s.ID, stateFileExtension) } // lockPath is the full path to the lock file. 
func (s *StateFile) lockPath() string { - return filepath.Join(s.RootDir, s.ID+".lock") + return buildPath(s.RootDir, s.ID, "lock") } // destroy deletes all state created by the stateFile. It may be called with the diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 96c57a426..c56e1d4d0 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -29,9 +29,12 @@ go_test( srcs = ["fsgofer_test.go"], library = ":fsgofer", deps = [ + "//pkg/fd", "//pkg/log", "//pkg/p9", "//pkg/test/testutil", + "//runsc/specutils", + "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 0b628c8ce..3d94ffeb4 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -49,6 +49,21 @@ const ( allowedOpenFlags = unix.O_TRUNC ) +var ( + // Remember the process uid/gid to skip chown calls when file owner/group + // doesn't need to be changed. + processUID = p9.UID(os.Getuid()) + processGID = p9.GID(os.Getgid()) +) + +// join is equivalent to path.Join() but skips path.Clean() which is expensive. +func join(parent, child string) string { + if child == "." || child == ".." { + panic(fmt.Sprintf("invalid child path %q", child)) + } + return parent + "/" + child +} + // Config sets configuration options for each attach point. type Config struct { // ROMount is set to true if this is a readonly mount. @@ -115,7 +130,7 @@ func (a *attachPoint) Attach() (p9.File, error) { return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err) } - lf, err := newLocalFile(a, f, a.prefix, readable, stat) + lf, err := newLocalFile(a, f, a.prefix, readable, &stat) if err != nil { return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err) } @@ -124,7 +139,7 @@ func (a *attachPoint) Attach() (p9.File, error) { } // makeQID returns a unique QID for the given stat buffer. 
-func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID { +func (a *attachPoint) makeQID(stat *unix.Stat_t) p9.QID { a.deviceMu.Lock() defer a.deviceMu.Unlock() @@ -245,7 +260,7 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { } func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) { - pathDebug := path.Join(parent.hostPath, name) + pathDebug := join(parent.hostPath, name) f, readable, err := openAnyFile(pathDebug, func(mode int) (*fd.FD, error) { return fd.OpenAt(parent.file, name, openFlags|mode, 0) }) @@ -297,8 +312,8 @@ func openAnyFile(pathDebug string, fn func(mode int) (*fd.FD, error)) (*fd.FD, b return nil, false, extractErrno(err) } -func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error { - switch stat.Mode & unix.S_IFMT { +func checkSupportedFileType(mode uint32, permitSocket bool) error { + switch mode & unix.S_IFMT { case unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK: return nil @@ -313,8 +328,8 @@ func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error { } } -func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat unix.Stat_t) (*localFile, error) { - if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil { +func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat *unix.Stat_t) (*localFile, error) { + if err := checkSupportedFileType(stat.Mode, a.conf.HostUDS); err != nil { return nil, err } @@ -442,8 +457,10 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode, }) defer cu.Clean() - if err := fchown(child.FD(), uid, gid); err != nil { - return nil, nil, p9.QID{}, 0, extractErrno(err) + if uid != processUID || gid != processGID { + if err := fchown(child.FD(), uid, gid); err != nil { + return nil, nil, p9.QID{}, 0, extractErrno(err) + } } stat, err := fstat(child.FD()) if err != nil { @@ -452,11 +469,11 @@ func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode, c := 
&localFile{ attachPoint: l.attachPoint, - hostPath: path.Join(l.hostPath, name), + hostPath: join(l.hostPath, name), file: child, mode: mode, fileType: unix.S_IFREG, - qid: l.attachPoint.makeQID(stat), + qid: l.attachPoint.makeQID(&stat), } cu.Release() @@ -488,8 +505,10 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) } defer f.Close() - if err := fchown(f.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) + if uid != processUID || gid != processGID { + if err := fchown(f.FD(), uid, gid); err != nil { + return p9.QID{}, extractErrno(err) + } } stat, err := fstat(f.FD()) if err != nil { @@ -497,7 +516,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) } cu.Release() - return l.attachPoint.makeQID(stat), nil + return l.attachPoint.makeQID(&stat), nil } // Walk implements p9.File. @@ -512,7 +531,7 @@ func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, if err != nil { return nil, nil, p9.AttrMask{}, p9.Attr{}, err } - mask, attr := l.fillAttr(stat) + mask, attr := l.fillAttr(&stat) return qids, file, mask, attr, nil } @@ -538,13 +557,13 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error) file: newFile, mode: invalidMode, fileType: l.fileType, - qid: l.attachPoint.makeQID(stat), + qid: l.attachPoint.makeQID(&stat), controlReadable: readable, } return []p9.QID{c.qid}, c, stat, nil } - var qids []p9.QID + qids := make([]p9.QID, 0, len(names)) var lastStat unix.Stat_t last := l for _, name := range names { @@ -560,7 +579,7 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error) _ = f.Close() return nil, nil, unix.Stat_t{}, extractErrno(err) } - c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat) + c, err := newLocalFile(last.attachPoint, f, path, readable, &lastStat) if err != nil { _ = f.Close() return nil, nil, unix.Stat_t{}, extractErrno(err) @@ -609,11 +628,11 @@ func (l 
*localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) if err != nil { return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err) } - mask, attr := l.fillAttr(stat) + mask, attr := l.fillAttr(&stat) return l.qid, mask, attr, nil } -func (l *localFile) fillAttr(stat unix.Stat_t) (p9.AttrMask, p9.Attr) { +func (l *localFile) fillAttr(stat *unix.Stat_t) (p9.AttrMask, p9.Attr) { attr := p9.Attr{ Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), @@ -881,8 +900,10 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. } defer f.Close() - if err := fchown(f.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) + if uid != processUID || gid != processGID { + if err := fchown(f.FD(), uid, gid); err != nil { + return p9.QID{}, extractErrno(err) + } } stat, err := fstat(f.FD()) if err != nil { @@ -890,7 +911,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. } cu.Release() - return l.attachPoint.makeQID(stat), nil + return l.attachPoint.makeQID(&stat), nil } // Link implements p9.File. @@ -938,8 +959,10 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid } defer child.Close() - if err := fchown(child.FD(), uid, gid); err != nil { - return p9.QID{}, extractErrno(err) + if uid != processUID || gid != processGID { + if err := fchown(child.FD(), uid, gid); err != nil { + return p9.QID{}, extractErrno(err) + } } stat, err := fstat(child.FD()) if err != nil { @@ -947,7 +970,7 @@ func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid } cu.Release() - return l.attachPoint.makeQID(stat), nil + return l.attachPoint.makeQID(&stat), nil } // UnlinkAt implements p9.File. 
@@ -1045,7 +1068,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) log.Warningf("Readdir is skipping file with failed stat %q, err: %v", l.hostPath, err) continue } - qid := l.attachPoint.makeQID(stat) + qid := l.attachPoint.makeQID(&stat) offset++ dirents = append(dirents, p9.Dirent{ QID: qid, @@ -1139,7 +1162,7 @@ func (l *localFile) isOpen() bool { // Renamed implements p9.Renamed. func (l *localFile) Renamed(newDir p9.File, newName string) { - l.hostPath = path.Join(newDir.(*localFile).hostPath, newName) + l.hostPath = join(newDir.(*localFile).hostPath, newName) } // extractErrno tries to determine the errno. diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index a84206686..c5daebe5e 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -23,10 +23,13 @@ import ( "path/filepath" "testing" + "github.com/syndtr/gocapability/capability" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/runsc/specutils" ) var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} @@ -197,10 +200,13 @@ func setup(fileType uint32) (string, string, error) { switch fileType { case unix.S_IFREG: name = "file" - _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) + fd, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) } + if fd != nil { + fd.Close() + } defer f.Close() case unix.S_IFDIR: name = "dir" @@ -556,7 +562,28 @@ func TestROMountChecks(t *testing.T) { func TestWalkNotFound(t *testing.T) { runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT { - t.Errorf("%v: Walk(%q) should have 
failed, got: %v, expected: unix.ENOENT", s, "nobody-here", err) + t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody-here", err) + } + if _, _, err := s.file.Walk([]string{"nobody", "here"}); err != unix.ENOENT { + t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "nobody/here", err) + } + if !s.conf.ROMount { + if _, err := s.file.Mkdir("dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { + t.Fatalf("MkDir(dir) failed, err: %v", err) + } + if _, _, err := s.file.Walk([]string{"dir", "nobody-here"}); err != unix.ENOENT { + t.Errorf("Walk(%q) should have failed, got: %v, expected: unix.ENOENT", "dir/nobody-here", err) + } + } + }) +} + +func TestWalkPanic(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { + for _, name := range []string{".", ".."} { + assertPanic(t, func() { + s.file.Walk([]string{name}) + }) } }) } @@ -574,6 +601,27 @@ func TestWalkDup(t *testing.T) { }) } +func TestWalkMultiple(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + var names []string + var parent p9.File = s.file + for i := 0; i < 5; i++ { + name := fmt.Sprintf("dir%d", i) + names = append(names, name) + + if _, err := parent.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { + t.Fatalf("MkDir(%q) failed, err: %v", name, err) + } + + var err error + _, parent, err = s.file.Walk(names) + if err != nil { + t.Errorf("Walk(%q): %v", name, err) + } + } + }) +} + func TestReaddir(t *testing.T) { runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { name := "dir" @@ -819,3 +867,168 @@ func TestMknod(t *testing.T) { } }) } + +func BenchmarkWalkOne(b *testing.B) { + path, name, err := setup(unix.S_IFDIR) + if err != nil { + b.Fatalf("%v", err) + } + defer os.RemoveAll(path) + + a, err := NewAttachPoint(path, Config{}) + if err != nil { + b.Fatalf("NewAttachPoint failed: %v", err) + } + root, err 
:= a.Attach() + if err != nil { + b.Fatalf("Attach failed, err: %v", err) + } + defer root.Close() + + names := []string{name} + files := make([]p9.File, 0, 1000) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, file, err := root.Walk(names) + if err != nil { + b.Fatalf("Walk(%q): %v", name, err) + } + files = append(files, file) + + // Avoid running out of FDs. + if len(files) == cap(files) { + b.StopTimer() + for _, file := range files { + file.Close() + } + files = files[:0] + b.StartTimer() + } + } + + b.StopTimer() + for _, file := range files { + file.Close() + } +} + +func BenchmarkCreate(b *testing.B) { + path, _, err := setup(unix.S_IFDIR) + if err != nil { + b.Fatalf("%v", err) + } + defer os.RemoveAll(path) + + a, err := NewAttachPoint(path, Config{}) + if err != nil { + b.Fatalf("NewAttachPoint failed: %v", err) + } + root, err := a.Attach() + if err != nil { + b.Fatalf("Attach failed, err: %v", err) + } + defer root.Close() + + files := make([]p9.File, 0, 500) + fds := make([]*fd.FD, 0, 500) + uid := p9.UID(os.Getuid()) + gid := p9.GID(os.Getgid()) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + name := fmt.Sprintf("same-%d", i) + fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, uid, gid) + if err != nil { + b.Fatalf("Create(%q): %v", name, err) + } + files = append(files, file) + if fd != nil { + fds = append(fds, fd) + } + + // Avoid running out of FDs. 
+ if len(files) == cap(files) { + b.StopTimer() + for _, file := range files { + file.Close() + } + files = files[:0] + for _, fd := range fds { + fd.Close() + } + fds = fds[:0] + b.StartTimer() + } + } + + b.StopTimer() + for _, file := range files { + file.Close() + } + for _, fd := range fds { + fd.Close() + } +} + +func BenchmarkCreateDiffOwner(b *testing.B) { + if !specutils.HasCapabilities(capability.CAP_CHOWN) { + b.Skipf("Test requires CAP_CHOWN") + } + + path, _, err := setup(unix.S_IFDIR) + if err != nil { + b.Fatalf("%v", err) + } + defer os.RemoveAll(path) + + a, err := NewAttachPoint(path, Config{}) + if err != nil { + b.Fatalf("NewAttachPoint failed: %v", err) + } + root, err := a.Attach() + if err != nil { + b.Fatalf("Attach failed, err: %v", err) + } + defer root.Close() + + files := make([]p9.File, 0, 500) + fds := make([]*fd.FD, 0, 500) + gid := p9.GID(os.Getgid()) + const nobody = 65534 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + name := fmt.Sprintf("diff-%d", i) + fd, file, _, _, err := root.Create(name, p9.ReadOnly, 0777, nobody, gid) + if err != nil { + b.Fatalf("Create(%q): %v", name, err) + } + files = append(files, file) + if fd != nil { + fds = append(fds, fd) + } + + // Avoid running out of FDs. + if len(files) == cap(files) { + b.StopTimer() + for _, file := range files { + file.Close() + } + files = files[:0] + for _, fd := range fds { + fd.Close() + } + fds = fds[:0] + b.StartTimer() + } + } + + b.StopTimer() + for _, file := range files { + file.Close() + } + for _, fd := range fds { + fd.Close() + } +} diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index d8112e7a2..9e429f7d5 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -279,8 +279,6 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( ll := syscall.SockaddrLinklayer{ Protocol: protocol, Ifindex: iface.Index, - Hatype: 0, // No ARP type. 
- Pkttype: syscall.PACKET_OTHERHOST, } if err := syscall.Bind(fd, &ll); err != nil { return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index c84ebcd8a..c1d13a58d 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -999,54 +999,30 @@ func (s *Sandbox) HeapProfile(f *os.File) error { } defer conn.Close() - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, + opts := control.HeapProfileOpts{ + FilePayload: urpc.FilePayload{Files: []*os.File{f}}, } - if err := conn.Call(boot.HeapProfile, &opts, nil); err != nil { - return fmt.Errorf("getting sandbox %q heap profile: %v", s.ID, err) - } - return nil + return conn.Call(boot.HeapProfile, &opts, nil) } -// StartCPUProfile start CPU profile writing to the given file. -func (s *Sandbox) StartCPUProfile(f *os.File) error { - log.Debugf("CPU profile start %q", s.ID) +// CPUProfile collects a CPU profile. +func (s *Sandbox) CPUProfile(f *os.File, duration time.Duration) error { + log.Debugf("CPU profile %q", s.ID) conn, err := s.sandboxConnect() if err != nil { return err } defer conn.Close() - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, - } - if err := conn.Call(boot.StartCPUProfile, &opts, nil); err != nil { - return fmt.Errorf("starting sandbox %q CPU profile: %v", s.ID, err) + opts := control.CPUProfileOpts{ + FilePayload: urpc.FilePayload{Files: []*os.File{f}}, + Duration: duration, } - return nil -} - -// StopCPUProfile stops a previously started CPU profile. 
-func (s *Sandbox) StopCPUProfile() error { - log.Debugf("CPU profile stop %q", s.ID) - conn, err := s.sandboxConnect() - if err != nil { - return err - } - defer conn.Close() - - if err := conn.Call(boot.StopCPUProfile, nil, nil); err != nil { - return fmt.Errorf("stopping sandbox %q CPU profile: %v", s.ID, err) - } - return nil + return conn.Call(boot.CPUProfile, &opts, nil) } // BlockProfile writes a block profile to the given file. -func (s *Sandbox) BlockProfile(f *os.File) error { +func (s *Sandbox) BlockProfile(f *os.File, duration time.Duration) error { log.Debugf("Block profile %q", s.ID) conn, err := s.sandboxConnect() if err != nil { @@ -1054,19 +1030,15 @@ func (s *Sandbox) BlockProfile(f *os.File) error { } defer conn.Close() - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, + opts := control.BlockProfileOpts{ + FilePayload: urpc.FilePayload{Files: []*os.File{f}}, + Duration: duration, } - if err := conn.Call(boot.BlockProfile, &opts, nil); err != nil { - return fmt.Errorf("getting sandbox %q block profile: %v", s.ID, err) - } - return nil + return conn.Call(boot.BlockProfile, &opts, nil) } // MutexProfile writes a mutex profile to the given file. -func (s *Sandbox) MutexProfile(f *os.File) error { +func (s *Sandbox) MutexProfile(f *os.File, duration time.Duration) error { log.Debugf("Mutex profile %q", s.ID) conn, err := s.sandboxConnect() if err != nil { @@ -1074,50 +1046,27 @@ func (s *Sandbox) MutexProfile(f *os.File) error { } defer conn.Close() - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, - } - if err := conn.Call(boot.MutexProfile, &opts, nil); err != nil { - return fmt.Errorf("getting sandbox %q mutex profile: %v", s.ID, err) - } - return nil -} - -// StartTrace start trace writing to the given file. 
-func (s *Sandbox) StartTrace(f *os.File) error { - log.Debugf("Trace start %q", s.ID) - conn, err := s.sandboxConnect() - if err != nil { - return err - } - defer conn.Close() - - opts := control.ProfileOpts{ - FilePayload: urpc.FilePayload{ - Files: []*os.File{f}, - }, - } - if err := conn.Call(boot.StartTrace, &opts, nil); err != nil { - return fmt.Errorf("starting sandbox %q trace: %v", s.ID, err) + opts := control.MutexProfileOpts{ + FilePayload: urpc.FilePayload{Files: []*os.File{f}}, + Duration: duration, } - return nil + return conn.Call(boot.MutexProfile, &opts, nil) } -// StopTrace stops a previously started trace. -func (s *Sandbox) StopTrace() error { - log.Debugf("Trace stop %q", s.ID) +// Trace collects an execution trace. +func (s *Sandbox) Trace(f *os.File, duration time.Duration) error { + log.Debugf("Trace %q", s.ID) conn, err := s.sandboxConnect() if err != nil { return err } defer conn.Close() - if err := conn.Call(boot.StopTrace, nil, nil); err != nil { - return fmt.Errorf("stopping sandbox %q trace: %v", s.ID, err) + opts := control.TraceProfileOpts{ + FilePayload: urpc.FilePayload{Files: []*os.File{f}}, + Duration: duration, } - return nil + return conn.Call(boot.Trace, &opts, nil) } // ChangeLogging changes logging options. diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md index d1bbabf6f..1bfb4a129 100644 --- a/test/benchmarks/README.md +++ b/test/benchmarks/README.md @@ -81,11 +81,8 @@ benchmarks. In general, benchmarks should look like this: ```golang - -var h harness.Harness - func BenchmarkMyCoolOne(b *testing.B) { - machine, err := h.GetMachine() + machine, err := harness.GetMachine() // check err defer machine.CleanUp() @@ -95,14 +92,14 @@ func BenchmarkMyCoolOne(b *testing.B) { b.ResetTimer() - //Respect b.N. + // Respect b.N. 
for i := 0; i < b.N; i++ { out, err := container.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/my-cool-image", Env: []string{"MY_VAR=awesome"}, other options...see dockerutil }, "sh", "-c", "echo MY_VAR") - //check err + // check err... b.StopTimer() // Do parsing and reporting outside of the timer. @@ -114,16 +111,13 @@ func BenchmarkMyCoolOne(b *testing.B) { } func TestMain(m *testing.M) { - h.Init() + harness.Init() os.Exit(m.Run()) } ``` Some notes on the above: -* The harness is initiated in the TestMain method and made global to test - module. The harness will handle any presetup that needs to happen with - flags, remote virtual machines (eventually), and other services. * Respect `b.N` in that users of the benchmark may want to "run for an hour" or something of the sort. * Use the `b.ReportMetric()` method to report custom metrics. diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go index acc49cc7c..452926e5f 100644 --- a/test/benchmarks/base/size_test.go +++ b/test/benchmarks/base/size_test.go @@ -26,12 +26,10 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) -var testHarness harness.Harness - // BenchmarkSizeEmpty creates N empty containers and reads memory usage from // /proc/meminfo. func BenchmarkSizeEmpty(b *testing.B) { - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -81,7 +79,7 @@ func BenchmarkSizeEmpty(b *testing.B) { // BenchmarkSizeNginx starts N containers running Nginx, checks that they're // serving, and checks memory used based on /proc/meminfo. 
func BenchmarkSizeNginx(b *testing.B) { - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -126,7 +124,7 @@ func BenchmarkSizeNginx(b *testing.B) { // BenchmarkSizeNode starts N containers running a Node app, checks that // they're serving, and checks memory used based on /proc/meminfo. func BenchmarkSizeNode(b *testing.B) { - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -178,6 +176,6 @@ func BenchmarkSizeNode(b *testing.B) { // TestMain is the main method for package network. func TestMain(m *testing.M) { - testHarness.Init() + harness.Init() os.Exit(m.Run()) } diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go index 8ef9f99c4..05a43ad17 100644 --- a/test/benchmarks/base/startup_test.go +++ b/test/benchmarks/base/startup_test.go @@ -25,11 +25,9 @@ import ( "gvisor.dev/gvisor/test/benchmarks/harness" ) -var testHarness harness.Harness - // BenchmarkStartEmpty times startup time for an empty container. func BenchmarkStartupEmpty(b *testing.B) { - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -53,7 +51,7 @@ func BenchmarkStartupEmpty(b *testing.B) { // Time is measured from start until the first request is served. func BenchmarkStartupNginx(b *testing.B) { // The machine to hold Nginx and the Node Server. - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -76,7 +74,7 @@ func BenchmarkStartupNginx(b *testing.B) { // Time is measured from start until the first request is served. // Note that the Node app connects to a Redis instance before serving. 
func BenchmarkStartupNode(b *testing.B) { - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -126,8 +124,8 @@ func runServerWorkload(ctx context.Context, b *testing.B, args base.ServerArgs) return fmt.Errorf("failed to get ip from server: %v", err) } - harness.DebugLog(b, "Waiting for container to start.") // Wait until the Client sees the server as up. + harness.DebugLog(b, "Waiting for container to start.") if err := harness.WaitUntilServing(ctx, args.Machine, servingIP, args.Port); err != nil { return fmt.Errorf("failed to wait for serving: %v", err) } @@ -141,6 +139,6 @@ func runServerWorkload(ctx context.Context, b *testing.B, args base.ServerArgs) // TestMain is the main method for package network. func TestMain(m *testing.M) { - testHarness.Init() + harness.Init() os.Exit(m.Run()) } diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go index bbb797e14..80569687c 100644 --- a/test/benchmarks/base/sysbench_test.go +++ b/test/benchmarks/base/sysbench_test.go @@ -23,8 +23,6 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) -var testHarness harness.Harness - type testCase struct { name string test tools.Sysbench @@ -32,42 +30,34 @@ type testCase struct { // BenchmarSysbench runs sysbench on the runtime. 
func BenchmarkSysbench(b *testing.B) { - testCases := []testCase{ testCase{ name: "CPU", test: &tools.SysbenchCPU{ - Base: tools.SysbenchBase{ + SysbenchBase: tools.SysbenchBase{ Threads: 1, - Time: 5, }, - MaxPrime: 50000, }, }, testCase{ name: "Memory", test: &tools.SysbenchMemory{ - Base: tools.SysbenchBase{ + SysbenchBase: tools.SysbenchBase{ Threads: 1, }, - BlockSize: "1M", - TotalSize: "500G", }, }, testCase{ name: "Mutex", test: &tools.SysbenchMutex{ - Base: tools.SysbenchBase{ + SysbenchBase: tools.SysbenchBase{ Threads: 8, }, - Loops: 1, - Locks: 10000000, - Num: 4, }, }, } - machine, err := testHarness.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -87,12 +77,15 @@ func BenchmarkSysbench(b *testing.B) { sysbench := machine.GetContainer(ctx, b) defer sysbench.CleanUp(ctx) + cmd := tc.test.MakeCmd(b) + b.ResetTimer() out, err := sysbench.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/sysbench", - }, tc.test.MakeCmd()...) + }, cmd...) if err != nil { b.Fatalf("failed to run sysbench: %v: logs:%s", err, out) } + b.StopTimer() tc.test.Report(b, out) }) } diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD index bfa7f71b6..0b1743603 100644 --- a/test/benchmarks/database/BUILD +++ b/test/benchmarks/database/BUILD @@ -7,11 +7,10 @@ go_library( name = "database", testonly = 1, srcs = ["database.go"], - deps = ["//test/benchmarks/harness"], ) benchmark_test( - name = "database_test", + name = "redis_test", size = "enormous", srcs = ["redis_test.go"], library = ":database", diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go index 9eeb59f9a..c15ca661c 100644 --- a/test/benchmarks/database/database.go +++ b/test/benchmarks/database/database.go @@ -14,18 +14,3 @@ // Package database holds benchmarks around database applications. 
package database - -import ( - "os" - "testing" - - "gvisor.dev/gvisor/test/benchmarks/harness" -) - -var h harness.Harness - -// TestMain is the main method for package database. -func TestMain(m *testing.M) { - h.Init() - os.Exit(m.Run()) -} diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go index 02e67154e..f3c4522ac 100644 --- a/test/benchmarks/database/redis_test.go +++ b/test/benchmarks/database/redis_test.go @@ -16,6 +16,7 @@ package database import ( "context" + "os" "testing" "time" @@ -49,13 +50,13 @@ var operations []string = []string{ // BenchmarkRedis runs redis-benchmark against a redis instance and reports // data in queries per second. Each is reported by named operation (e.g. LPUSH). func BenchmarkRedis(b *testing.B) { - clientMachine, err := h.GetMachine() + clientMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } defer clientMachine.CleanUp() - serverMachine, err := h.GetMachine() + serverMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -64,7 +65,6 @@ func BenchmarkRedis(b *testing.B) { // Redis runs on port 6379 by default. port := 6379 ctx := context.Background() - for _, operation := range operations { param := tools.Parameter{ Name: "operation", @@ -104,28 +104,26 @@ func BenchmarkRedis(b *testing.B) { b.Fatalf("failed to start redis with: %v", err) } + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + redis := tools.Redis{ Operation: operation, } - - // Reset profiles and timer to begin the measurement. - server.RestartProfiles() b.ResetTimer() - for i := 0; i < b.N; i++ { - client := clientMachine.GetNativeContainer(ctx, b) - defer client.CleanUp(ctx) - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/redis", - }, redis.MakeCmd(ip, serverPort)...) 
- if err != nil { - b.Fatalf("redis-benchmark failed with: %v", err) - } - - // Stop time while we parse results. - b.StopTimer() - redis.Report(b, out) - b.StartTimer() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }, redis.MakeCmd(ip, serverPort, b.N /*requests*/)...) + if err != nil { + b.Fatalf("redis-benchmark failed with: %v", err) } + b.StopTimer() + redis.Report(b, out) }) } } + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index 53ed3f9f2..8baeff0db 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -25,8 +25,6 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) -var h harness.Harness - // Note: CleanCache versions of this test require running with root permissions. func BenchmarkBuildABSL(b *testing.B) { runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...") @@ -41,7 +39,7 @@ func BenchmarkBuildRunsc(b *testing.B) { func runBuildBenchmark(b *testing.B, image, workdir, target string) { b.Helper() // Get a machine from the Harness on which to run. - machine, err := h.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -61,10 +59,10 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { for _, bm := range benchmarks { pageCache := tools.Parameter{ Name: "page_cache", - Value: "clean", + Value: "dirty", } if bm.clearCache { - pageCache.Value = "dirty" + pageCache.Value = "clean" } filesystem := tools.Parameter{ @@ -102,21 +100,20 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { prefix = "/tmp" } - // Restart profiles after the copy. - container.RestartProfiles() b.ResetTimer() + b.StopTimer() + // Drop Caches and bazel clean should happen inside the loop as we may use // time options with b.N. (e.g. Run for an hour.) 
for i := 0; i < b.N; i++ { - b.StopTimer() // Drop Caches for clear cache runs. if bm.clearCache { if err := harness.DropCaches(machine); err != nil { b.Skipf("failed to drop caches: %v. You probably need root.", err) } } - b.StartTimer() + b.StartTimer() got, err := container.Exec(ctx, dockerutil.ExecOpts{ WorkDir: prefix + workdir, }, "bazel", "build", "-c", "opt", target) @@ -129,14 +126,15 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { if !strings.Contains(got, want) { b.Fatalf("string %s not in: %s", want, got) } - // Clean bazel in case we use b.N. - _, err = container.Exec(ctx, dockerutil.ExecOpts{ - WorkDir: prefix + workdir, - }, "bazel", "clean") - if err != nil { - b.Fatalf("build failed with: %v", err) + + // Clean bazel in the case we are doing another run. + if i < b.N-1 { + if _, err = container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: prefix + workdir, + }, "bazel", "clean"); err != nil { + b.Fatalf("build failed with: %v", err) + } } - b.StartTimer() } }) } @@ -144,6 +142,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { // TestMain is the main method for package fs. func TestMain(m *testing.M) { - h.Init() + harness.Init() + harness.SetFixedBenchmarks() os.Exit(m.Run()) } diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go index 96340373c..0c772b768 100644 --- a/test/benchmarks/fs/fio_test.go +++ b/test/benchmarks/fs/fio_test.go @@ -27,42 +27,50 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) -var h harness.Harness - // BenchmarkFio runs fio on the runtime under test. There are 4 basic test // cases each run on a tmpfs mount and a bind mount. Fio requires root so that // caches can be dropped. 
func BenchmarkFio(b *testing.B) { testCases := []tools.Fio{ tools.Fio{ - Test: "write", - Size: "5G", - Blocksize: "1M", - Iodepth: 4, + Test: "write4K", + Size: b.N, + BlockSize: 4, + IODepth: 4, + }, + tools.Fio{ + Test: "write1M", + Size: b.N, + BlockSize: 1024, + IODepth: 4, + }, + tools.Fio{ + Test: "read4K", + Size: b.N, + BlockSize: 4, + IODepth: 4, }, tools.Fio{ - Test: "read", - Size: "5G", - Blocksize: "1M", - Iodepth: 4, + Test: "read1M", + Size: b.N, + BlockSize: 1024, + IODepth: 4, }, tools.Fio{ - Test: "randwrite", - Size: "5G", - Blocksize: "4K", - Iodepth: 4, - Time: 30, + Test: "randwrite4K", + Size: b.N, + BlockSize: 4, + IODepth: 4, }, tools.Fio{ - Test: "randread", - Size: "5G", - Blocksize: "4K", - Iodepth: 4, - Time: 30, + Test: "randread4K", + Size: b.N, + BlockSize: 4, + IODepth: 4, }, } - machine, err := h.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -116,7 +124,7 @@ func BenchmarkFio(b *testing.B) { // For reads, we need a file to read so make one inside the container. if strings.Contains(tc.Test, "read") { - fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.Size, outfile) + fallocateCmd := fmt.Sprintf("fallocate -l %dK %s", tc.Size, outfile) if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, strings.Split(fallocateCmd, " ")...); err != nil { b.Fatalf("failed to create readable file on mount: %v, %s", err, out) @@ -128,22 +136,24 @@ func BenchmarkFio(b *testing.B) { b.Skipf("failed to drop caches with %v. You probably need root.", err) } cmd := tc.MakeCmd(outfile) - container.RestartProfiles() + b.ResetTimer() + b.StopTimer() + for i := 0; i < b.N; i++ { + if err := harness.DropCaches(machine); err != nil { + b.Fatalf("failed to drop caches: %v", err) + } + // Run fio. + b.StartTimer() data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...) 
if err != nil { b.Fatalf("failed to run cmd %v: %v", cmd, err) } b.StopTimer() + b.SetBytes(1024 * 1024) // Bytes for go reporting (Size is in megabytes). tc.Report(b, data) - // If b.N is used (i.e. we run for an hour), we should drop caches - // after each run. - if err := harness.DropCaches(machine); err != nil { - b.Fatalf("failed to drop caches: %v", err) - } - b.StartTimer() } }) } @@ -185,6 +195,6 @@ func makeMount(machine harness.Machine, mountType mount.Type, target string) (mo // TestMain is the main method for package fs. func TestMain(m *testing.M) { - h.Init() + harness.Init() os.Exit(m.Run()) } diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go index 4c6e724aa..a853b7ba8 100644 --- a/test/benchmarks/harness/harness.go +++ b/test/benchmarks/harness/harness.go @@ -28,12 +28,8 @@ var ( debug = flag.Bool("debug", false, "turns on debug messages for individual benchmarks") ) -// Harness is a handle for managing state in benchmark runs. -type Harness struct { -} - // Init performs any harness initilialization before runs. -func (h *Harness) Init() error { +func Init() error { flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s -- --test.bench=<regex>\n", os.Args[0]) flag.PrintDefaults() @@ -47,7 +43,15 @@ func (h *Harness) Init() error { return nil } +// SetFixedBenchmarks causes all benchmarks to run once. +// +// This must be set if they cannot scale with N. Note that this uses 1ns +// instead of 1x due to https://github.com/golang/go/issues/32051. +func SetFixedBenchmarks() { + flag.Set("test.benchtime", "1ns") +} + // GetMachine returns this run's implementation of machine. 
-func (h *Harness) GetMachine() (Machine, error) { +func GetMachine() (Machine, error) { return &localMachine{}, nil } diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD index 46e8dc8b5..380783f0b 100644 --- a/test/benchmarks/media/BUILD +++ b/test/benchmarks/media/BUILD @@ -7,12 +7,11 @@ go_library( name = "media", testonly = 1, srcs = ["media.go"], - deps = ["//test/benchmarks/harness"], ) benchmark_test( - name = "media_test", - size = "large", + name = "ffmpeg_test", + size = "enormous", srcs = ["ffmpeg_test.go"], library = ":media", visibility = ["//:sandbox"], diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go index 7822dfad7..1b99a319a 100644 --- a/test/benchmarks/media/ffmpeg_test.go +++ b/test/benchmarks/media/ffmpeg_test.go @@ -15,6 +15,7 @@ package media import ( "context" + "os" "strings" "testing" @@ -25,29 +26,36 @@ import ( // BenchmarkFfmpeg runs ffmpeg in a container and records runtime. // BenchmarkFfmpeg should run as root to drop caches. func BenchmarkFfmpeg(b *testing.B) { - machine, err := h.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } defer machine.CleanUp() ctx := context.Background() - container := machine.GetContainer(ctx, b) - defer container.CleanUp(ctx) cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ") b.ResetTimer() + b.StopTimer() + for i := 0; i < b.N; i++ { - b.StopTimer() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) if err := harness.DropCaches(machine); err != nil { b.Skipf("failed to drop caches: %v. 
You probably need root.", err) } - b.StartTimer() + b.StartTimer() if _, err := container.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/ffmpeg", }, cmd...); err != nil { b.Fatalf("failed to run container: %v", err) } + b.StopTimer() } } + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go index c7b35b758..ed7b24651 100644 --- a/test/benchmarks/media/media.go +++ b/test/benchmarks/media/media.go @@ -14,18 +14,3 @@ // Package media holds benchmarks around media processing applications. package media - -import ( - "os" - "testing" - - "gvisor.dev/gvisor/test/benchmarks/harness" -) - -var h harness.Harness - -// TestMain is the main method for package media. -func TestMain(m *testing.M) { - h.Init() - os.Exit(m.Run()) -} diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 02ff6966f..285ec35d9 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -7,12 +7,11 @@ go_library( name = "ml", testonly = 1, srcs = ["ml.go"], - deps = ["//test/benchmarks/harness"], ) benchmark_test( - name = "ml_test", - size = "large", + name = "tensorflow_test", + size = "enormous", srcs = ["tensorflow_test.go"], library = ":ml", visibility = ["//:sandbox"], diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go index 13282d7bb..d5fc5b7da 100644 --- a/test/benchmarks/ml/ml.go +++ b/test/benchmarks/ml/ml.go @@ -14,18 +14,3 @@ // Package ml holds benchmarks around machine learning performance. package ml - -import ( - "os" - "testing" - - "gvisor.dev/gvisor/test/benchmarks/harness" -) - -var h harness.Harness - -// TestMain is the main method for package ml. 
-func TestMain(m *testing.M) { - h.Init() - os.Exit(m.Run()) -} diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go index f7746897d..b0e0c4720 100644 --- a/test/benchmarks/ml/tensorflow_test.go +++ b/test/benchmarks/ml/tensorflow_test.go @@ -15,6 +15,7 @@ package ml import ( "context" + "os" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -35,7 +36,7 @@ func BenchmarkTensorflow(b *testing.B) { "NeuralNetwork": "3_NeuralNetworks/neural_network.py", } - machine, err := h.GetMachine() + machine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -44,17 +45,19 @@ func BenchmarkTensorflow(b *testing.B) { for name, workload := range workloads { b.Run(name, func(b *testing.B) { ctx := context.Background() - container := machine.GetContainer(ctx, b) - defer container.CleanUp(ctx) b.ResetTimer() + b.StopTimer() + for i := 0; i < b.N; i++ { - b.StopTimer() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) if err := harness.DropCaches(machine); err != nil { b.Skipf("failed to drop caches: %v. You probably need root.", err) } - b.StartTimer() + // Run tensorflow. 
+ b.StartTimer() if out, err := container.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/tensorflow", Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"}, @@ -62,8 +65,14 @@ func BenchmarkTensorflow(b *testing.B) { }, "python", workload); err != nil { b.Fatalf("failed to run container: %v logs: %s", err, out) } + b.StopTimer() } }) } +} +func TestMain(m *testing.M) { + harness.Init() + harness.SetFixedBenchmarks() + os.Exit(m.Run()) } diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index c75d1ce11..2741570f5 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -8,7 +8,6 @@ go_library( testonly = 1, srcs = [ "network.go", - "static_server.go", ], deps = [ "//pkg/test/dockerutil", @@ -18,19 +17,76 @@ go_library( ) benchmark_test( - name = "network_test", - size = "large", + name = "iperf_test", + size = "enormous", srcs = [ - "httpd_test.go", "iperf_test.go", - "nginx_test.go", + ], + library = ":network", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) + +benchmark_test( + name = "node_test", + size = "enormous", + srcs = [ "node_test.go", + ], + library = ":network", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) + +benchmark_test( + name = "ruby_test", + size = "enormous", + srcs = [ "ruby_test.go", ], library = ":network", visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) + +benchmark_test( + name = "nginx_test", + size = "enormous", + srcs = [ + "nginx_test.go", + ], + library = ":network", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) + +benchmark_test( + name = "httpd_test", + size = "enormous", + srcs = [ + 
"httpd_test.go", + ], + library = ":network", + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", "//pkg/test/testutil", "//test/benchmarks/harness", "//test/benchmarks/tools", diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go index 8d7d5f750..629127250 100644 --- a/test/benchmarks/network/httpd_test.go +++ b/test/benchmarks/network/httpd_test.go @@ -14,10 +14,12 @@ package network import ( + "os" "strconv" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) @@ -34,18 +36,20 @@ var httpdDocs = map[string]string{ // BenchmarkHttpd iterates over different sized payloads and concurrency, testing // how well the runtime handles sending different payload sizes. func BenchmarkHttpd(b *testing.B) { - benchmarkHttpdDocSize(b, false /* reverse */) + benchmarkHttpdDocSize(b) } -// BenchmarkReverseHttpd iterates over different sized payloads, testing -// how well the runtime handles receiving different payload sizes. -func BenchmarkReverseHttpd(b *testing.B) { - benchmarkHttpdDocSize(b, true /* reverse */) +// BenchmarkContinuousHttpd runs specific benchmarks for continous jobs. +// The runtime under test is the server serving a runc client. +func BenchmarkContinuousHttpd(b *testing.B) { + sizes := []string{"10Kb", "100Kb", "1Mb"} + threads := []int{1, 25, 100, 1000} + benchmarkHttpdContinuous(b, threads, sizes) } // benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks // for each size. 
-func benchmarkHttpdDocSize(b *testing.B, reverse bool) { +func benchmarkHttpdDocSize(b *testing.B) { b.Helper() for size, filename := range httpdDocs { concurrency := []int{1, 25, 50, 100, 1000} @@ -64,18 +68,49 @@ func benchmarkHttpdDocSize(b *testing.B, reverse bool) { } b.Run(name, func(b *testing.B) { hey := &tools.Hey{ - Requests: c * b.N, + Requests: b.N, Concurrency: c, Doc: filename, } - runHttpd(b, hey, reverse) + runHttpd(b, hey) + }) + } + } +} + +// benchmarkHttpdContinuous iterates through given sizes and concurrencies. +func benchmarkHttpdContinuous(b *testing.B, concurrency []int, sizes []string) { + for _, size := range sizes { + filename := httpdDocs[size] + for _, c := range concurrency { + fsize := tools.Parameter{ + Name: "filesize", + Value: size, + } + + threads := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + + name, err := tools.ParametersToName(fsize, threads) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } + b.Run(name, func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N, + Concurrency: c, + Doc: filename, + } + runHttpd(b, hey) }) } } } // runHttpd configures the static serving methods to run httpd. -func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) { +func runHttpd(b *testing.B, hey *tools.Hey) { // httpd runs on port 80. 
port := 80 httpdRunOpts := dockerutil.RunOpts{ @@ -91,5 +126,10 @@ func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) { }, } httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"} - runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse) + runStaticServer(b, httpdRunOpts, httpdCmd, port, hey) +} + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) } diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go index b8ab7dfb8..5e81149fe 100644 --- a/test/benchmarks/network/iperf_test.go +++ b/test/benchmarks/network/iperf_test.go @@ -15,6 +15,7 @@ package network import ( "context" + "os" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -25,16 +26,16 @@ import ( func BenchmarkIperf(b *testing.B) { iperf := tools.Iperf{ - Time: 10, // time in seconds to run client. + Num: b.N, } - clientMachine, err := h.GetMachine() + clientMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } defer clientMachine.CleanUp() - serverMachine, err := h.GetMachine() + serverMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine: %v", err) } @@ -91,23 +92,22 @@ func BenchmarkIperf(b *testing.B) { if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil { b.Fatalf("failed to wait for server: %v", err) } + // Run the client. b.ResetTimer() - - // Restart the server profiles. If the server isn't being profiled - // this does nothing. - server.RestartProfiles() - for i := 0; i < b.N; i++ { - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/iperf", - }, iperf.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("failed to run client: %v", err) - } - b.StopTimer() - iperf.Report(b, out) - b.StartTimer() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + }, iperf.MakeCmd(ip, servingPort)...) 
+ if err != nil { + b.Fatalf("failed to run client: %v", err) } + b.StopTimer() + iperf.Report(b, out) }) } } + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go index ce17ddb94..d61002cea 100644 --- a/test/benchmarks/network/network.go +++ b/test/benchmarks/network/network.go @@ -16,16 +16,65 @@ package network import ( - "os" + "context" "testing" + "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" ) -var h harness.Harness +// runStaticServer runs static serving workloads (httpd, nginx). +func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey) { + ctx := context.Background() -// TestMain is the main method for package network. -func TestMain(m *testing.M) { - h.Init() - os.Exit(m.Run()) + // Get two machines: a client and server. + clientMachine, err := harness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := harness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Make the containers. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // Start the server. + if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil { + b.Fatalf("failed to start server: %v", err) + } + + // Get its IP. + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + // Get the published port. + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Make sure the server is serving. 
+ harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + + // Run the client. + b.ResetTimer() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + b.StopTimer() + hey.Report(b, out) } diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go index 08565d0b2..74f3578fc 100644 --- a/test/benchmarks/network/nginx_test.go +++ b/test/benchmarks/network/nginx_test.go @@ -14,10 +14,12 @@ package network import ( + "os" "strconv" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) @@ -34,19 +36,21 @@ var nginxDocs = map[string]string{ // BenchmarkNginxDocSize iterates over different sized payloads, testing how // well the runtime handles sending different payload sizes. func BenchmarkNginxDocSize(b *testing.B) { - benchmarkNginxDocSize(b, false /* reverse */, true /* tmpfs */) - benchmarkNginxDocSize(b, false /* reverse */, false /* tmpfs */) + benchmarkNginxDocSize(b, true /* tmpfs */) + benchmarkNginxDocSize(b, false /* tmpfs */) } -// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing -// how well the runtime handles receiving different payload sizes. -func BenchmarkReverseNginxDocSize(b *testing.B) { - benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */) +// BenchmarkContinuousNginx runs specific benchmarks for continous jobs. +// The runtime under test is the sever serving a runc client. +func BenchmarkContinuousNginx(b *testing.B) { + sizes := []string{"10Kb", "100Kb", "1Mb"} + threads := []int{1, 25, 100, 1000} + benchmarkNginxContinuous(b, threads, sizes) } // benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks // for each size. 
-func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) { +func benchmarkNginxDocSize(b *testing.B, tmpfs bool) { for size, filename := range nginxDocs { concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { @@ -71,21 +75,56 @@ func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) { if err != nil { b.Fatalf("Failed to parse parameters: %v", err) } + b.Run(name, func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N, + Concurrency: c, + Doc: filename, + } + runNginx(b, hey, tmpfs) + }) + } + } +} + +// benchmarkNginxContinuous iterates through given sizes and concurrencies on a tmpfs mount. +func benchmarkNginxContinuous(b *testing.B, concurrency []int, sizes []string) { + for _, size := range sizes { + filename := nginxDocs[size] + for _, c := range concurrency { + fsize := tools.Parameter{ + Name: "filesize", + Value: size, + } + threads := tools.Parameter{ + Name: "concurrency", + Value: strconv.Itoa(c), + } + + fs := tools.Parameter{ + Name: "filesystem", + Value: "tmpfs", + } + + name, err := tools.ParametersToName(fsize, threads, fs) + if err != nil { + b.Fatalf("Failed to parse parameters: %v", err) + } b.Run(name, func(b *testing.B) { hey := &tools.Hey{ - Requests: c * b.N, + Requests: b.N, Concurrency: c, Doc: filename, } - runNginx(b, hey, reverse, tmpfs) + runNginx(b, hey, true /*tmpfs*/) }) } } } // runNginx configures the static serving methods to run httpd. -func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) { +func runNginx(b *testing.B, hey *tools.Hey, tmpfs bool) { // nginx runs on port 80. port := 80 nginxRunOpts := dockerutil.RunOpts{ @@ -99,5 +138,10 @@ func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) { } // Command copies nginxDocs to tmpfs serving directory and runs nginx. 
- runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse) + runStaticServer(b, nginxRunOpts, nginxCmd, port, hey) +} + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) } diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go index 254538899..a1fc82f95 100644 --- a/test/benchmarks/network/node_test.go +++ b/test/benchmarks/network/node_test.go @@ -15,6 +15,7 @@ package network import ( "context" + "os" "strconv" "testing" "time" @@ -41,7 +42,7 @@ func BenchmarkNode(b *testing.B) { } b.Run(name, func(b *testing.B) { hey := &tools.Hey{ - Requests: b.N * c, // Requests b.N requests per thread. + Requests: b.N, Concurrency: c, } runNode(b, hey) @@ -54,14 +55,14 @@ func runNode(b *testing.B, hey *tools.Hey) { b.Helper() // The machine to hold Redis and the Node Server. - serverMachine, err := h.GetMachine() + serverMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } defer serverMachine.CleanUp() // The machine to run 'hey'. - clientMachine, err := h.GetMachine() + clientMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -116,10 +117,8 @@ func runNode(b *testing.B, hey *tools.Hey) { heyCmd := hey.MakeCmd(servingIP, servingPort) - nodeApp.RestartProfiles() - b.ResetTimer() - // the client should run on Native. + b.ResetTimer() client := clientMachine.GetNativeContainer(ctx, b) out, err := client.Run(ctx, dockerutil.RunOpts{ Image: "benchmarks/hey", @@ -129,7 +128,10 @@ func runNode(b *testing.B, hey *tools.Hey) { } // Stop the timer to parse the data and report stats. 
- b.StopTimer() hey.Report(b, out) - b.StartTimer() +} + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) } diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go index 0174ff3f3..b7ec16e0a 100644 --- a/test/benchmarks/network/ruby_test.go +++ b/test/benchmarks/network/ruby_test.go @@ -16,6 +16,7 @@ package network import ( "context" "fmt" + "os" "strconv" "testing" "time" @@ -42,7 +43,7 @@ func BenchmarkRuby(b *testing.B) { } b.Run(name, func(b *testing.B) { hey := &tools.Hey{ - Requests: b.N * c, // b.N requests per thread. + Requests: b.N, Concurrency: c, } runRuby(b, hey) @@ -52,16 +53,15 @@ func BenchmarkRuby(b *testing.B) { // runRuby runs the test for a given # of requests and concurrency. func runRuby(b *testing.B, hey *tools.Hey) { - b.Helper() // The machine to hold Redis and the Ruby Server. - serverMachine, err := h.GetMachine() + serverMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } defer serverMachine.CleanUp() // The machine to run 'hey'. - clientMachine, err := h.GetMachine() + clientMachine, err := harness.GetMachine() if err != nil { b.Fatalf("failed to get machine with: %v", err) } @@ -123,10 +123,9 @@ func runRuby(b *testing.B, hey *tools.Hey) { b.Fatalf("failed to wait until serving: %v", err) } heyCmd := hey.MakeCmd(servingIP, servingPort) - rubyApp.RestartProfiles() - b.ResetTimer() // the client should run on Native. + b.ResetTimer() client := clientMachine.GetNativeContainer(ctx, b) defer client.CleanUp(ctx) out, err := client.Run(ctx, dockerutil.RunOpts{ @@ -135,9 +134,11 @@ func runRuby(b *testing.B, hey *tools.Hey) { if err != nil { b.Fatalf("hey container failed: %v logs: %s", err, out) } - - // Stop the timer to parse the data and report stats. 
b.StopTimer() hey.Report(b, out) - b.StartTimer() +} + +func TestMain(m *testing.M) { + harness.Init() + os.Exit(m.Run()) } diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go deleted file mode 100644 index e747a1395..000000000 --- a/test/benchmarks/network/static_server.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package network - -import ( - "context" - "testing" - - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/benchmarks/harness" - "gvisor.dev/gvisor/test/benchmarks/tools" -) - -// runStaticServer runs static serving workloads (httpd, nginx). -func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) { - ctx := context.Background() - - // Get two machines: a client and server. - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer serverMachine.CleanUp() - - // Make the containers. 'reverse=true' specifies that the client should use the - // runtime under test. 
- var client, server *dockerutil.Container - if reverse { - client = clientMachine.GetContainer(ctx, b) - server = serverMachine.GetNativeContainer(ctx, b) - } else { - client = clientMachine.GetNativeContainer(ctx, b) - server = serverMachine.GetContainer(ctx, b) - } - defer client.CleanUp(ctx) - defer server.CleanUp(ctx) - - // Start the server. - if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil { - b.Fatalf("failed to start server: %v", err) - } - - // Get its IP. - ip, err := serverMachine.IPAddress() - if err != nil { - b.Fatalf("failed to find server ip: %v", err) - } - - // Get the published port. - servingPort, err := server.FindPort(ctx, port) - if err != nil { - b.Fatalf("failed to find server port %d: %v", port, err) - } - - // Make sure the server is serving. - harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) - b.ResetTimer() - server.RestartProfiles() - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, hey.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("run failed with: %v", err) - } - - b.StopTimer() - hey.Report(b, out) - b.StartTimer() -} diff --git a/test/benchmarks/tools/fio.go b/test/benchmarks/tools/fio.go index f5f60fa84..f6324c3ab 100644 --- a/test/benchmarks/tools/fio.go +++ b/test/benchmarks/tools/fio.go @@ -25,25 +25,20 @@ import ( // Fio makes 'fio' commands and parses their output. type Fio struct { Test string // test to run: read, write, randread, randwrite. - Size string // total size to be read/written of format N[GMK] (e.g. 5G). - Blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K). - Iodepth int // iodepth for reads/writes. - Time int // time to run the test in seconds, usually for rand(read/write). + Size int // total size to be read/written in megabytes. + BlockSize int // block size to be read/written in kilobytes. + IODepth int // I/O depth for reads/writes. } // MakeCmd makes a 'fio' command. 
func (f *Fio) MakeCmd(filename string) []string { cmd := []string{"fio", "--output-format=json", "--ioengine=sync"} cmd = append(cmd, fmt.Sprintf("--name=%s", f.Test)) - cmd = append(cmd, fmt.Sprintf("--size=%s", f.Size)) - cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.Blocksize)) + cmd = append(cmd, fmt.Sprintf("--size=%dM", f.Size)) + cmd = append(cmd, fmt.Sprintf("--blocksize=%dK", f.BlockSize)) cmd = append(cmd, fmt.Sprintf("--filename=%s", filename)) - cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.Iodepth)) + cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.IODepth)) cmd = append(cmd, fmt.Sprintf("--rw=%s", f.Test)) - if f.Time != 0 { - cmd = append(cmd, "--time_based") - cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.Time)) - } return cmd } diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go index b8cb938fe..de908feeb 100644 --- a/test/benchmarks/tools/hey.go +++ b/test/benchmarks/tools/hey.go @@ -19,7 +19,6 @@ import ( "net" "regexp" "strconv" - "strings" "testing" ) @@ -32,8 +31,16 @@ type Hey struct { // MakeCmd returns a 'hey' command. func (h *Hey) MakeCmd(ip net.IP, port int) []string { - return strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/%s", - h.Requests, h.Concurrency, ip, port, h.Doc), " ") + c := h.Concurrency + if c > h.Requests { + c = h.Requests + } + return []string{ + "hey", + "-n", fmt.Sprintf("%d", h.Requests), + "-c", fmt.Sprintf("%d", c), + fmt.Sprintf("http://%s:%d/%s", ip.String(), port, h.Doc), + } } // Report parses output from 'hey' and reports metrics. diff --git a/test/benchmarks/tools/iperf.go b/test/benchmarks/tools/iperf.go index 5c4e7125b..abf296731 100644 --- a/test/benchmarks/tools/iperf.go +++ b/test/benchmarks/tools/iperf.go @@ -19,19 +19,27 @@ import ( "net" "regexp" "strconv" - "strings" "testing" ) +const length = 64 * 1024 + // Iperf is for the client side of `iperf`. type Iperf struct { - Time int + Num int } // MakeCmd returns a iperf client command. 
func (i *Iperf) MakeCmd(ip net.IP, port int) []string { - // iperf report in Kb realtime - return strings.Split(fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", i.Time, ip, port), " ") + return []string{ + "iperf", + "--format", "K", // Output in KBytes. + "--realtime", // Measured in realtime. + "--num", fmt.Sprintf("%d", i.Num), + "--length", fmt.Sprintf("%d", length), + "--client", ip.String(), + "--port", fmt.Sprintf("%d", port), + } } // Report parses output from iperf client and reports metrics. @@ -42,6 +50,7 @@ func (i *Iperf) Report(b *testing.B, output string) { if err != nil { b.Fatalf("failed to parse bandwitdth from %s: %v", output, err) } + b.SetBytes(length) // Measure Bytes/sec for b.N, although below is iperf output. ReportCustomMetric(b, bW*1024, "bandwidth" /*metric name*/, "bytes_per_second" /*unit*/) } diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go index e35886437..12fdbc7cc 100644 --- a/test/benchmarks/tools/redis.go +++ b/test/benchmarks/tools/redis.go @@ -19,7 +19,6 @@ import ( "net" "regexp" "strconv" - "strings" "testing" ) @@ -29,17 +28,29 @@ type Redis struct { } // MakeCmd returns a redis-benchmark client command. -func (r *Redis) MakeCmd(ip net.IP, port int) []string { +func (r *Redis) MakeCmd(ip net.IP, port, requests int) []string { // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case. // Note that "ping" will run both PING_INLINE and PING_BULK. if r.Operation == "PING_BULK" { - return strings.Split( - fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, port), " ") + return []string{ + "redis-benchmark", + "--csv", + "-t", "ping", + "-h", ip.String(), + "-p", fmt.Sprintf("%d", port), + "-n", fmt.Sprintf("%d", requests), + } } // runs redis-benchmark -t operation for 100K requests against server. 
- return strings.Split( - fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", r.Operation, ip, port), " ") + return []string{ + "redis-benchmark", + "--csv", + "-t", r.Operation, + "-h", ip.String(), + "-p", fmt.Sprintf("%d", port), + "-n", fmt.Sprintf("%d", requests), + } } // Report parses output from redis-benchmark client and reports metrics. diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go index 7ccacd8ff..2b8e6c8aa 100644 --- a/test/benchmarks/tools/sysbench.go +++ b/test/benchmarks/tools/sysbench.go @@ -18,58 +18,46 @@ import ( "fmt" "regexp" "strconv" - "strings" "testing" ) -var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&" - // Sysbench represents a 'sysbench' command. type Sysbench interface { - MakeCmd() []string // Makes a sysbench command. - flags() []string - Report(*testing.B, string) // Reports results contained in string. + // MakeCmd constructs the relevant command line. + MakeCmd(*testing.B) []string + + // Report reports relevant custom metrics. + Report(*testing.B, string) } // SysbenchBase is the top level struct for sysbench and holds top-level arguments // for sysbench. See: 'sysbench --help' type SysbenchBase struct { - Threads int // number of Threads for the test. - Time int // time limit for test in seconds. + // Threads is the number of threads for the test. + Threads int } // baseFlags returns top level flags. -func (s *SysbenchBase) baseFlags() []string { +func (s *SysbenchBase) baseFlags(b *testing.B) []string { var ret []string if s.Threads > 0 { ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads)) } - if s.Time > 0 { - ret = append(ret, fmt.Sprintf("--time=%d", s.Time)) - } + ret = append(ret, "--time=0") // Ensure events is used. + ret = append(ret, fmt.Sprintf("--events=%d", b.N)) return ret } // SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments. 
type SysbenchCPU struct { - Base SysbenchBase - MaxPrime int // upper limit for primes generator [10000]. + SysbenchBase } // MakeCmd makes commands for SysbenchCPU. -func (s *SysbenchCPU) MakeCmd() []string { - cmd := []string{warmup, "sysbench"} - cmd = append(cmd, s.flags()...) - cmd = append(cmd, "cpu run") - return []string{"sh", "-c", strings.Join(cmd, " ")} -} - -// flags makes flags for SysbenchCPU cmds. -func (s *SysbenchCPU) flags() []string { - cmd := s.Base.baseFlags() - if s.MaxPrime > 0 { - return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime)) - } +func (s *SysbenchCPU) MakeCmd(b *testing.B) []string { + cmd := []string{"sysbench"} + cmd = append(cmd, s.baseFlags(b)...) + cmd = append(cmd, "cpu", "run") return cmd } @@ -96,9 +84,9 @@ func (s *SysbenchCPU) parseEvents(data string) (float64, error) { // SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments. type SysbenchMemory struct { - Base SysbenchBase - BlockSize string // size of test memory block [1K]. - TotalSize string // size of data to transfer [100G]. + SysbenchBase + BlockSize int // size of test memory block in megabytes [1]. + TotalSize int // size of data to transfer in gigabytes [100]. Scope string // memory access scope {global, local} [global]. HugeTLB bool // allocate memory from HugeTLB [off]. OperationType string // type of memory ops {read, write, none} [write]. @@ -106,21 +94,21 @@ type SysbenchMemory struct { } // MakeCmd makes commands for SysbenchMemory. -func (s *SysbenchMemory) MakeCmd() []string { - cmd := []string{warmup, "sysbench"} - cmd = append(cmd, s.flags()...) - cmd = append(cmd, "memory run") - return []string{"sh", "-c", strings.Join(cmd, " ")} +func (s *SysbenchMemory) MakeCmd(b *testing.B) []string { + cmd := []string{"sysbench"} + cmd = append(cmd, s.flags(b)...) + cmd = append(cmd, "memory", "run") + return cmd } // flags makes flags for SysbenchMemory cmds. 
-func (s *SysbenchMemory) flags() []string { - cmd := s.Base.baseFlags() - if s.BlockSize != "" { - cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize)) +func (s *SysbenchMemory) flags(b *testing.B) []string { + cmd := s.baseFlags(b) + if s.BlockSize != 0 { + cmd = append(cmd, fmt.Sprintf("--memory-block-size=%dM", s.BlockSize)) } - if s.TotalSize != "" { - cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize)) + if s.TotalSize != 0 { + cmd = append(cmd, fmt.Sprintf("--memory-total-size=%dG", s.TotalSize)) } if s.Scope != "" { cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope)) @@ -147,7 +135,7 @@ func (s *SysbenchMemory) Report(b *testing.B, output string) { ReportCustomMetric(b, result, "memory_operations" /*metric name*/, "ops_per_second" /*unit*/) } -var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`) +var memoryOperationsRE = regexp.MustCompile(`Total\s+operations:\s+\d+\s+\((\s*\d+\.\d+\s*)\s+per\s+second\)`) // parseOperations parses memory operations per second form sysbench memory ouput. func (s *SysbenchMemory) parseOperations(data string) (float64, error) { @@ -160,24 +148,24 @@ func (s *SysbenchMemory) parseOperations(data string) (float64, error) { // SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments. type SysbenchMutex struct { - Base SysbenchBase + SysbenchBase Num int // total size of mutex array [4096]. - Locks int // number of mutex locks per thread [50K]. - Loops int // number of loops to do outside mutex lock [10K]. + Locks int // number of mutex locks per thread [50000]. + Loops int // number of loops to do outside mutex lock [10000]. } // MakeCmd makes commands for SysbenchMutex. -func (s *SysbenchMutex) MakeCmd() []string { - cmd := []string{warmup, "sysbench"} - cmd = append(cmd, s.flags()...) 
- cmd = append(cmd, "mutex run") - return []string{"sh", "-c", strings.Join(cmd, " ")} +func (s *SysbenchMutex) MakeCmd(b *testing.B) []string { + cmd := []string{"sysbench"} + cmd = append(cmd, s.flags(b)...) + cmd = append(cmd, "mutex", "run") + return cmd } // flags makes flags for SysbenchMutex commands. -func (s *SysbenchMutex) flags() []string { +func (s *SysbenchMutex) flags(b *testing.B) []string { var cmd []string - cmd = append(cmd, s.Base.baseFlags()...) + cmd = append(cmd, s.baseFlags(b)...) if s.Num > 0 { cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num)) } diff --git a/test/cmd/test_app/fds.go b/test/cmd/test_app/fds.go index a7658eefd..d4354f0d3 100644 --- a/test/cmd/test_app/fds.go +++ b/test/cmd/test_app/fds.go @@ -16,6 +16,7 @@ package main import ( "context" + "io" "io/ioutil" "log" "os" @@ -168,8 +169,8 @@ func (fdr *fdReceiver) Execute(ctx context.Context, f *flag.FlagSet, args ...int file := os.NewFile(uintptr(fd), "received file") defer file.Close() - if _, err := file.Seek(0, os.SEEK_SET); err != nil { - log.Fatalf("Seek(0, 0) failed: %v", err) + if _, err := file.Seek(0, io.SeekStart); err != nil { + log.Fatalf("Error from seek(0, 0): %v", err) } got, err := ioutil.ReadAll(file) diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 03bdfa889..d07ed6ba5 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -260,12 +260,10 @@ func TestMemLimit(t *testing.T) { d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) - // N.B. Because the size of the memory file may grow in large chunks, - // there is a minimum threshold of 1GB for the MemTotal figure. - allocMemory := 1024 * 1024 // In kb. + allocMemoryKb := 50 * 1024 out, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", - Memory: allocMemory * 1024, // In bytes. + Memory: allocMemoryKb * 1024, // In bytes. 
}, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'") if err != nil { t.Fatalf("docker run failed: %v", err) @@ -285,7 +283,7 @@ func TestMemLimit(t *testing.T) { if err != nil { t.Fatalf("failed to parse %q: %v", out, err) } - if want := uint64(allocMemory); got != want { + if want := uint64(allocMemoryKb); got != want { t.Errorf("MemTotal got: %d, want: %d", got, want) } } diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go index 70bbe5121..84564cdaa 100644 --- a/test/e2e/regression_test.go +++ b/test/e2e/regression_test.go @@ -35,7 +35,7 @@ func TestBindOverlay(t *testing.T) { // Run the container. got, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/ubuntu", - }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p") + }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p") if err != nil { t.Fatalf("docker run failed: %v", err) } diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD index d1fb178e8..2f745bd47 100644 --- a/test/fuse/linux/BUILD +++ b/test/fuse/linux/BUILD @@ -235,6 +235,7 @@ cc_binary( srcs = ["mount_test.cc"], deps = [ gtest, + "//test/util:mount_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", diff --git a/test/fuse/linux/mount_test.cc b/test/fuse/linux/mount_test.cc index a5c2fbb01..8a5478116 100644 --- a/test/fuse/linux/mount_test.cc +++ b/test/fuse/linux/mount_test.cc @@ -17,6 +17,7 @@ #include <sys/mount.h> #include "gtest/gtest.h" +#include "test/util/mount_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -25,6 +26,17 @@ namespace testing { namespace { +TEST(FuseMount, Success) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); + std::string mopts = absl::StrCat("fd=", std::to_string(fd.get())); + + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + const 
auto mount = + ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); +} + TEST(FuseMount, FDNotParsable) { int devfd; EXPECT_THAT(devfd = open("/dev/fuse", O_RDWR), SyscallSucceeds()); @@ -35,6 +47,36 @@ TEST(FuseMount, FDNotParsable) { SyscallFailsWithErrno(EINVAL)); } +TEST(FuseMount, NoDevice) { + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FuseMount, ClosedFD) { + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); + int fd = f.release(); + close(fd); + std::string mopts = absl::StrCat("fd=", std::to_string(fd)); + + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(FuseMount, BadFD) { + const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + std::string mopts = absl::StrCat("fd=", std::to_string(fd.get())); + + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + } // namespace } // namespace testing diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index d3e5efd4f..f4af45e96 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -248,7 +248,7 @@ func (FilterOutputOwnerFail) Name() string { // ContainerAction implements TestCase.ContainerAction. 
func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { - return fmt.Errorf("Invalid argument") + return fmt.Errorf("invalid argument") } return nil diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index 49642f282..5d95516ee 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -38,6 +38,15 @@ packetdrill_test( scripts = ["tcp_defer_accept_timeout.pkt"], ) +test_suite( + name = "all_tests", + tags = [ + "manual", + "packetdrill", + ], + tests = existing_rules(), +) + bzl_library( name = "defs_bzl", srcs = ["defs.bzl"], diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index 8be2c6526..3e26c73cb 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -162,7 +162,7 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut Image: "packetimpact", CapAdd: []string{"NET_ADMIN"}, } - if _, err := mountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil { + if _, err := MountTempDirectory(t, &runOpts, "dut-output", testOutputDir); err != nil { return dutInfo{}, err } @@ -228,7 +228,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co Image: "packetimpact", CapAdd: []string{"NET_ADMIN"}, } - if _, err := mountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil { + if _, err := MountTempDirectory(t, &runOpts, "testbench-output", testOutputDir); err != nil { t.Fatal(err) } tbb := path.Base(testbenchBinary) @@ -565,11 +565,11 @@ func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerut return nil } -// mountTempDirectory creates a temporary directory on host with the template +// MountTempDirectory creates a temporary directory on host with the template // and then mounts it into the container under the name provided. 
The temporary // directory name is returned. Content in that directory will be copied to // TEST_UNDECLARED_OUTPUTS_DIR in cleanup phase. -func mountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) { +func MountTempDirectory(t *testing.T, runOpts *dockerutil.RunOpts, hostDirTemplate, containerDir string) (string, error) { t.Helper() tmpDir, err := ioutil.TempDir("", hostDirTemplate) if err != nil { diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 50b9ccf68..576577310 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -306,11 +306,11 @@ func (s *tcpState) incoming(received Layer) Layer { if s.remoteSeqNum != nil { newIn.SeqNum = Uint32(uint32(*s.remoteSeqNum)) } - if s.localSeqNum != nil && (*tcpReceived.Flags&header.TCPFlagAck) != 0 { + if seq, flags := s.localSeqNum, tcpReceived.Flags; seq != nil && flags != nil && *flags&header.TCPFlagAck != 0 { // The caller didn't specify an AckNum so we'll expect the calculated one, // but only if the ACK flag is set because the AckNum is not valid in a // header if ACK is not set. 
- newIn.AckNum = Uint32(uint32(*s.localSeqNum)) + newIn.AckNum = Uint32(uint32(*seq)) } return &newIn } @@ -603,9 +603,9 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du } if gotLayers == nil { if errs == nil { - return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) + return nil, fmt.Errorf("got no frames matching %s during %s", layers, timeout) } - return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout) + return nil, fmt.Errorf("got frames:\n%w want %s during %s", errs, layers, timeout) } if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { @@ -615,7 +615,12 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du } return gotLayers, nil } - errs = multierr.Combine(errs, &layersError{got: gotLayers, want: conn.incoming(gotLayers)}) + want := conn.incoming(layers) + if err := want.merge(layers); err != nil { + errs = multierr.Combine(errs, err) + } else { + errs = multierr.Combine(errs, &layersError{got: gotLayers, want: want}) + } } } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index dcff4ab36..19e6b8d7d 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -505,13 +505,13 @@ func (l *IPv6) ToBytes() ([]byte, error) { } } if l.NextHeader != nil { - fields.NextHeader = *l.NextHeader + fields.TransportProtocol = tcpip.TransportProtocolNumber(*l.NextHeader) } else { nh, err := nextHeaderByLayer(l.next()) if err != nil { return nil, err } - fields.NextHeader = nh + fields.TransportProtocol = tcpip.TransportProtocolNumber(nh) } if l.HopLimit != nil { fields.HopLimit = *l.HopLimit diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 373ab8d2f..b1b3c578b 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -383,3 +383,12 @@ validate_all_tests() expect_netstack_failure = hasattr(t, 
"expect_netstack_failure"), num_duts = t.num_duts if hasattr(t, "num_duts") else 1, ) for t in ALL_TESTS] + +test_suite( + name = "all_tests", + tags = [ + "manual", + "packetimpact", + ], + tests = existing_rules(), +) diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go index cf0431c57..d06690705 100644 --- a/test/packetimpact/tests/tcp_zero_receive_window_test.go +++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go @@ -15,7 +15,6 @@ package tcp_zero_receive_window_test import ( - "context" "flag" "fmt" "testing" @@ -46,16 +45,30 @@ func TestZeroReceiveWindow(t *testing.T) { dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) - samplePayload := &testbench.Payload{Bytes: make([]byte, payloadLen)} //testbench.GenerateRandomPayload(t, payloadLen)} - // Expect the DUT to eventually advertize zero receive window. + samplePayload := &testbench.Payload{Bytes: testbench.GenerateRandomPayload(t, payloadLen)} + // Expect the DUT to eventually advertise zero receive window. // The test would timeout otherwise. - for { + for readOnce := false; ; { conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected packet was not received: %s", err) } - if *gotTCP.WindowSize == 0 { + // Read once to trigger the subsequent window update from the + // DUT to grow the right edge of the receive window from what + // was advertised in the SYN-ACK. This ensures that we test + // for the full default buffer size (1MB on gVisor at the time + // of writing this comment), thus testing for cases when the + // scaled receive window size ends up > 65535 (0xffff). 
+ if !readOnce { + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) + } + readOnce = true + } + windowSize := *gotTCP.WindowSize + t.Logf("got window size = %d", windowSize) + if windowSize == 0 { break } } @@ -92,10 +105,9 @@ func TestNonZeroReceiveWindow(t *testing.T) { if err != nil { t.Fatalf("expected packet was not received: %s", err) } - if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(payloadLen), 0); ret == -1 { - t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, payloadLen, ret, err) + if got := dut.Recv(t, acceptFd, int32(payloadLen), 0); len(got) != payloadLen { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %d, want %d", acceptFd, payloadLen, len(got), payloadLen) } - if *gotTCP.WindowSize == 0 { t.Fatalf("expected non-zero receive window.") } diff --git a/test/root/BUILD b/test/root/BUILD index a9130b34f..8d9fff578 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -1,5 +1,4 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/vm:defs.bzl", "vm_test") package(licenses = ["notice"]) @@ -24,12 +23,8 @@ go_test( ], library = ":root", tags = [ - # Requires docker and runsc to be configured before the test runs. - # Also, the test needs to be run as root. Note that below, the - # root_vm_test relies on the default runtime 'runsc' being installed by - # the default installer. 
- "manual", "local", + "manual", ], visibility = ["//:sandbox"], deps = [ @@ -46,10 +41,3 @@ go_test( "@org_golang_x_sys//unix:go_default_library", ], ) - -vm_test( - name = "root_vm_test", - size = "large", - shard_count = 1, - targets = [":root_test"], -) diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go index 9272137ff..f2db5f9ea 100644 --- a/test/runtimes/runner/lib/lib.go +++ b/test/runtimes/runner/lib/lib.go @@ -196,3 +196,4 @@ func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } func (f testDeps) ImportPath() string { return "" } func (f testDeps) StartTestLog(io.Writer) {} func (f testDeps) StopTestLog() error { return nil } +func (f testDeps) SetPanicOnExit0(bool) {} diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 135d58ae6..a5b9233f7 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -647,6 +647,7 @@ syscall_test( syscall_test( size = "medium", + add_hostinet = True, test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test", ) @@ -658,12 +659,14 @@ syscall_test( syscall_test( size = "medium", + add_hostinet = True, shard_count = most_shards, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", ) syscall_test( size = "medium", + add_hostinet = True, test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test", ) @@ -680,6 +683,7 @@ syscall_test( syscall_test( size = "medium", + add_hostinet = True, shard_count = more_shards, # Takes too long under gotsan to run. 
tags = ["nogotsan"], @@ -728,6 +732,7 @@ syscall_test( ) syscall_test( + add_hostinet = True, test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", ) @@ -903,6 +908,7 @@ syscall_test( ) syscall_test( + add_hostinet = True, test = "//test/syscalls/linux:udp_bind_test", ) @@ -967,6 +973,7 @@ syscall_test( ) syscall_test( + add_hostinet = True, test = "//test/syscalls/linux:proc_net_tcp_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index a9e0b070a..4e0c8a574 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -21,6 +21,7 @@ exports_files( "socket_ip_unbound.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", + "socket_ipv6_udp_unbound_loopback.cc", "socket_ipv4_udp_unbound_loopback_nogotsan.cc", "tcp_socket.cc", "udp_bind.cc", @@ -944,6 +945,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", gtest, @@ -2451,6 +2453,27 @@ cc_library( ) cc_library( + name = "socket_ipv6_udp_unbound_test_cases", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound.cc", + ], + hdrs = [ + "socket_ipv6_udp_unbound.h", + ], + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + "@com_google_absl//absl/memory", + gtest, + "//test/util:posix_error", + "//test/util:save_util", + "//test/util:test_util", + ], + alwayslink = 1, +) + +cc_library( name = "socket_ipv4_udp_unbound_netlink_test_cases", testonly = 1, srcs = [ @@ -2790,6 +2813,22 @@ cc_binary( ) cc_binary( + name = "socket_ipv6_udp_unbound_loopback_test", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_loopback.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv6_udp_unbound_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = 
"socket_ipv4_udp_unbound_loopback_nogotsan_test", testonly = 1, srcs = [ @@ -3286,6 +3325,7 @@ cc_binary( ":socket_test_util", ":unix_domain_socket_test_util", gtest, + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", ], diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index b040cdcf7..93c692dd6 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -32,6 +32,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_map.h" #include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -381,7 +382,7 @@ TYPED_TEST(GetdentsTest, PartialBuffer) { // getdents iterates correctly despite mutation of /proc/self/fd. TYPED_TEST(GetdentsTest, ProcSelfFd) { constexpr size_t kNfds = 10; - std::unordered_map<int, FileDescriptor> fds; + absl::node_hash_map<int, FileDescriptor> fds; fds.reserve(kNfds); for (size_t i = 0; i < kNfds; i++) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()); diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index d65b7d031..15b645fb7 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -345,42 +345,6 @@ TEST(MountTest, RenameRemoveMountPoint) { ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } -TEST(MountTest, MountFuseFilesystemNoDevice) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // Before kernel version 4.16-rc6, FUSE mount is protected by - // capable(CAP_SYS_ADMIN). After this version, it uses - // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not - // allowed to mount fuse file systems without the global CAP_SYS_ADMIN. 
- int res = mount("", dir.path().c_str(), "fuse", 0, ""); - SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); - - EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), - SyscallFailsWithErrno(EINVAL)); -} - -TEST(MountTest, MountFuseFilesystem) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); - std::string mopts = "fd=" + std::to_string(fd.get()); - - auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - - // See comments in MountFuseFilesystemNoDevice for the reason why we skip - // EPERM when running on Linux. - int res = mount("", dir.path().c_str(), "fuse", 0, ""); - SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); - - auto const mount = - ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); -} - } // namespace } // namespace testing diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 77f390f3c..fcd162ca2 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -505,6 +505,18 @@ TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) { EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(OpenTest, OpenWithStrangeFlags) { + // VFS1 incorrectly allows read/write operations on such file descriptors. 
+ SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY | O_RDWR)); + EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF)); + char c; + EXPECT_THAT(read(fd.get(), &c, 1), SyscallFailsWithErrno(EBADF)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 06d9dbf65..01ccbdcd2 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -71,13 +71,13 @@ class PipeTest : public ::testing::TestWithParam<PipeCreator> { // Returns true iff the pipe represents a named pipe. bool IsNamedPipe() const { return named_pipe_; } - int Size() const { + size_t Size() const { int s1 = fcntl(rfd_.get(), F_GETPIPE_SZ); int s2 = fcntl(wfd_.get(), F_GETPIPE_SZ); EXPECT_GT(s1, 0); EXPECT_GT(s2, 0); EXPECT_EQ(s1, s2); - return s1; + return static_cast<size_t>(s1); } static void TearDownTestSuite() { @@ -568,7 +568,7 @@ TEST_P(PipeTest, Streaming) { DisableSave ds; // Size() requires 2 syscalls, call it once and remember the value. - const int pipe_size = Size(); + const size_t pipe_size = Size(); const size_t streamed_bytes = 4 * pipe_size; absl::Notification notify; @@ -576,7 +576,7 @@ TEST_P(PipeTest, Streaming) { std::vector<char> buf(1024); // Don't start until it's full. notify.WaitForNotification(); - ssize_t total = 0; + size_t total = 0; while (total < streamed_bytes) { ASSERT_THAT(read(rfd_.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -593,7 +593,7 @@ TEST_P(PipeTest, Streaming) { // page) for the check for notify.Notify() below to be correct. 
std::vector<char> buf(1024); RandomizeBuffer(buf.data(), buf.size()); - ssize_t total = 0; + size_t total = 0; while (total < streamed_bytes) { ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 575be014c..e508ce27f 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1802,6 +1802,33 @@ TEST(ProcPidCmdline, SubprocessForkSameCmdline) { } } +TEST(ProcPidCmdline, SubprocessSeekCmdline) { + FileDescriptor fd; + ASSERT_NO_ERRNO(WithSubprocess( + [&](int pid) -> PosixError { + // Running. Open /proc/pid/cmdline. + ASSIGN_OR_RETURN_ERRNO( + fd, Open(absl::StrCat("/proc/", pid, "/cmdline"), O_RDONLY)); + return NoError(); + }, + [&](int pid) -> PosixError { + // Zombie, but seek should still succeed. + int ret = lseek(fd.get(), 0x801, 0); + if (ret < 0) { + return PosixError(errno); + } + return NoError(); + }, + [&](int pid) -> PosixError { + // Exited. + int ret = lseek(fd.get(), 0x801, 0); + if (ret < 0) { + return PosixError(errno); + } + return NoError(); + })); +} + // Test whether /proc/PID/ symlinks can be read for a running process. 
TEST(ProcPidSymlink, SubprocessRunning) { char buf[1]; diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 23677e296..1cc700fe7 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -420,14 +420,14 @@ TEST(ProcNetSnmp, CheckNetStat) { int name_count = 0; int value_count = 0; std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); - for (int i = 0; i + 1 < lines.size(); i += 2) { + for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) { std::vector<absl::string_view> names = absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); std::vector<absl::string_view> values = absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] << "' and '" << lines[i + 1] << "'"; - for (int j = 0; j < names.size() && j < values.size(); ++j) { + for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) { if (names[j] == "TCPOrigDataSent" || names[j] == "TCPSynRetrans" || names[j] == "TCPDSACKRecv" || names[j] == "TCPDSACKOfoRecv") { ++name_count; @@ -457,14 +457,14 @@ TEST(ProcNetSnmp, CheckSnmp) { int name_count = 0; int value_count = 0; std::vector<absl::string_view> lines = absl::StrSplit(contents, '\n'); - for (int i = 0; i + 1 < lines.size(); i += 2) { + for (long unsigned int i = 0; i + 1 < lines.size(); i += 2) { std::vector<absl::string_view> names = absl::StrSplit(lines[i], absl::ByAnyChar("\t ")); std::vector<absl::string_view> values = absl::StrSplit(lines[i + 1], absl::ByAnyChar("\t ")); EXPECT_EQ(names.size(), values.size()) << " mismatch in lines '" << lines[i] << "' and '" << lines[i + 1] << "'"; - for (int j = 0; j < names.size() && j < values.size(); ++j) { + for (long unsigned int j = 0; j < names.size() && j < values.size(); ++j) { if (names[j] == "RetransSegs") { ++name_count; int64_t val; diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 
a63067586..662c6feb2 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -181,7 +181,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { // Returns true on match, and sets 'match' to point to the matching entry. bool FindBy(std::vector<UnixEntry> entries, UnixEntry* match, std::function<bool(const UnixEntry&)> predicate) { - for (int i = 0; i < entries.size(); ++i) { + for (long unsigned int i = 0; i < entries.size(); ++i) { if (predicate(entries[i])) { *match = entries[i]; return true; diff --git a/test/syscalls/linux/proc_pid_uid_gid_map.cc b/test/syscalls/linux/proc_pid_uid_gid_map.cc index 748f7be58..af052a63c 100644 --- a/test/syscalls/linux/proc_pid_uid_gid_map.cc +++ b/test/syscalls/linux/proc_pid_uid_gid_map.cc @@ -203,7 +203,8 @@ TEST_P(ProcSelfUidGidMapTest, IdentityMapOwnID) { EXPECT_THAT( InNewUserNamespaceWithMapFD([&](int fd) { DenySelfSetgroups(); - TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); + TEST_PCHECK(static_cast<long unsigned int>( + write(fd, line.c_str(), line.size())) == line.size()); }), IsPosixErrorOkAndHolds(0)); } @@ -220,7 +221,8 @@ TEST_P(ProcSelfUidGidMapTest, TrailingNewlineAndNULIgnored) { DenySelfSetgroups(); // The write should return the full size of the write, even though // characters after the NUL were ignored. 
- TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); + TEST_PCHECK(static_cast<long unsigned int>( + write(fd, line.c_str(), line.size())) == line.size()); }), IsPosixErrorOkAndHolds(0)); } @@ -233,7 +235,8 @@ TEST_P(ProcSelfUidGidMapTest, NonIdentityMapOwnID) { EXPECT_THAT( InNewUserNamespaceWithMapFD([&](int fd) { DenySelfSetgroups(); - TEST_PCHECK(write(fd, line.c_str(), line.size()) == line.size()); + TEST_PCHECK(static_cast<long unsigned int>( + write(fd, line.c_str(), line.size())) == line.size()); }), IsPosixErrorOkAndHolds(0)); } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 54709371c..955bcee4b 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -852,6 +852,51 @@ TEST(RawSocketTest, IPv6ProtoRaw) { SyscallFailsWithErrno(EINVAL)); } +TEST(RawSocketTest, IPv6SendMsg) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP), + SyscallSucceeds()); + + char kBuf[] = "hello"; + struct iovec iov = {}; + iov.iov_base = static_cast<void*>(const_cast<char*>(kBuf)); + iov.iov_len = static_cast<size_t>(sizeof(kBuf)); + + struct sockaddr_storage addr = {}; + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + struct msghdr msg = {}; + msg.msg_name = static_cast<void*>(&addr); + msg.msg_namelen = sizeof(sockaddr_in); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(RawSocketTest, ConnectOnIPv6Socket) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_TCP), + SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + struct sockaddr_in* sin = 
reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + ASSERT_THAT(connect(sock, reinterpret_cast<struct sockaddr*>(&addr), + sizeof(sockaddr_in6)), + SyscallFailsWithErrno(EAFNOSUPPORT)); +} + INSTANTIATE_TEST_SUITE_P( AllInetTests, RawSocketTest, ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 1c1bf6a57..0530fce44 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -20,6 +20,7 @@ #include <atomic> #include <cerrno> #include <ctime> +#include <set> #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -35,6 +36,17 @@ namespace gvisor { namespace testing { namespace { +constexpr int kSemMap = 1024000000; +constexpr int kSemMni = 32000; +constexpr int kSemMns = 1024000000; +constexpr int kSemMnu = 1024000000; +constexpr int kSemMsl = 32000; +constexpr int kSemOpm = 500; +constexpr int kSemUme = 500; +constexpr int kSemUsz = 20; +constexpr int kSemVmx = 32767; +constexpr int kSemAem = 32767; + class AutoSem { public: explicit AutoSem(int id) : id_(id) {} @@ -586,7 +598,7 @@ TEST(SemaphoreTest, SemopGetzcnt) { buf.sem_num = 0; buf.sem_op = 0; constexpr size_t kLoops = 10; - for (auto i = 0; i < kLoops; i++) { + for (size_t i = 0; i < kLoops; i++) { auto child_pid = fork(); if (child_pid == 0) { TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); @@ -693,7 +705,7 @@ TEST(SemaphoreTest, SemopGetncnt) { buf.sem_num = 0; buf.sem_op = -1; constexpr size_t kLoops = 10; - for (auto i = 0; i < kLoops; i++) { + for (size_t i = 0; i < kLoops; i++) { auto child_pid = fork(); if (child_pid == 0) { TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); @@ -774,18 +786,148 @@ TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) { } TEST(SemaphoreTest, IpcInfo) { + constexpr int kLoops = 5; + std::set<int> sem_ids; struct seminfo info; + // Drop CAP_IPC_OWNER which 
allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + for (int i = 0; i < kLoops; i++) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + sem_ids.insert(sem.release()); + } + ASSERT_EQ(sem_ids.size(), kLoops); + + int max_used_index = 0; + EXPECT_THAT(max_used_index = semctl(0, 0, IPC_INFO, &info), + SyscallSucceeds()); + + std::set<int> sem_ids_before_max_index; + for (int i = 0; i <= max_used_index; i++) { + struct semid_ds ds = {}; + int sem_id = semctl(i, 0, SEM_STAT, &ds); + // Only if index i is used within the registry. + if (sem_ids.find(sem_id) != sem_ids.end()) { + struct semid_ds ipc_stat_ds; + ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key); + EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid); + EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid); + EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid); + EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid); + EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode); + EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime); + EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime); + EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems); + + // Remove the semaphore set's read permission. + struct semid_ds ipc_set_ds; + ipc_set_ds.sem_perm.uid = getuid(); + ipc_set_ds.sem_perm.gid = getgid(); + // Keep the semaphore set's write permission so that it could be removed. 
+ ipc_set_ds.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds()); + ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES)); + + sem_ids_before_max_index.insert(sem_id); + } + } + EXPECT_EQ(sem_ids_before_max_index.size(), kLoops); + for (const int sem_id : sem_ids) { + ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds()); + } + ASSERT_THAT(semctl(0, 0, IPC_INFO, &info), SyscallSucceeds()); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + EXPECT_EQ(info.semusz, kSemUsz); + EXPECT_EQ(info.semvmx, kSemVmx); + EXPECT_EQ(info.semaem, kSemAem); +} + +TEST(SemaphoreTest, SemInfo) { + constexpr int kLoops = 5; + constexpr int kSemSetSize = 3; + std::set<int> sem_ids; + struct seminfo info; + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + for (int i = 0; i < kLoops; i++) { + AutoSem sem(semget(IPC_PRIVATE, kSemSetSize, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + sem_ids.insert(sem.release()); + } + ASSERT_EQ(sem_ids.size(), kLoops); + int max_used_index = 0; + EXPECT_THAT(max_used_index = semctl(0, 0, SEM_INFO, &info), + SyscallSucceeds()); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + // There could be semaphores existing in the system during the test, which + // prevents the test from getting a exact number, but the test could expect at + // least the number of sempahroes it creates in the begining of the test. 
+ EXPECT_GE(info.semusz, sem_ids.size()); + EXPECT_EQ(info.semvmx, kSemVmx); + EXPECT_GE(info.semaem, sem_ids.size() * kSemSetSize); + + std::set<int> sem_ids_before_max_index; + for (int i = 0; i <= max_used_index; i++) { + struct semid_ds ds = {}; + int sem_id = semctl(i, 0, SEM_STAT, &ds); + // Only if index i is used within the registry. + if (sem_ids.find(sem_id) != sem_ids.end()) { + struct semid_ds ipc_stat_ds; + ASSERT_THAT(semctl(sem_id, 0, IPC_STAT, &ipc_stat_ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.__key, ipc_stat_ds.sem_perm.__key); + EXPECT_EQ(ds.sem_perm.uid, ipc_stat_ds.sem_perm.uid); + EXPECT_EQ(ds.sem_perm.gid, ipc_stat_ds.sem_perm.gid); + EXPECT_EQ(ds.sem_perm.cuid, ipc_stat_ds.sem_perm.cuid); + EXPECT_EQ(ds.sem_perm.cgid, ipc_stat_ds.sem_perm.cgid); + EXPECT_EQ(ds.sem_perm.mode, ipc_stat_ds.sem_perm.mode); + EXPECT_EQ(ds.sem_otime, ipc_stat_ds.sem_otime); + EXPECT_EQ(ds.sem_ctime, ipc_stat_ds.sem_ctime); + EXPECT_EQ(ds.sem_nsems, ipc_stat_ds.sem_nsems); + + // Remove the semaphore set's read permission. + struct semid_ds ipc_set_ds; + ipc_set_ds.sem_perm.uid = getuid(); + ipc_set_ds.sem_perm.gid = getgid(); + // Keep the semaphore set's write permission so that it could be removed. 
+ ipc_set_ds.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem_id, 0, IPC_SET, &ipc_set_ds), SyscallSucceeds()); + ASSERT_THAT(semctl(i, 0, SEM_STAT, &ds), SyscallFailsWithErrno(EACCES)); + + sem_ids_before_max_index.insert(sem_id); + } + } + EXPECT_EQ(sem_ids_before_max_index.size(), kLoops); + for (const int sem_id : sem_ids) { + ASSERT_THAT(semctl(sem_id, 0, IPC_RMID), SyscallSucceeds()); + } - EXPECT_EQ(info.semmap, 1024000000); - EXPECT_EQ(info.semmni, 32000); - EXPECT_EQ(info.semmns, 1024000000); - EXPECT_EQ(info.semmnu, 1024000000); - EXPECT_EQ(info.semmsl, 32000); - EXPECT_EQ(info.semopm, 500); - EXPECT_EQ(info.semume, 500); - EXPECT_EQ(info.semvmx, 32767); - EXPECT_EQ(info.semaem, 32767); + ASSERT_THAT(semctl(0, 0, SEM_INFO, &info), SyscallSucceeds()); + EXPECT_EQ(info.semmap, kSemMap); + EXPECT_EQ(info.semmni, kSemMni); + EXPECT_EQ(info.semmns, kSemMns); + EXPECT_EQ(info.semmnu, kSemMnu); + EXPECT_EQ(info.semmsl, kSemMsl); + EXPECT_EQ(info.semopm, kSemOpm); + EXPECT_EQ(info.semume, kSemUme); + // Apart from semapahores that are not created by the test, we can't determine + // the exact number of semaphore sets and semaphores, as a result, semusz and + // semaem range from 0 to a random number. Since the numbers are always + // non-negative, the test will not check the reslts of semusz and semaem. 
+ EXPECT_EQ(info.semvmx, kSemVmx); } } // namespace diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index e680d3dd7..32f583581 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -46,7 +46,7 @@ TEST(SocketTest, ProtocolUnix) { {AF_UNIX, SOCK_SEQPACKET, PF_UNIX}, {AF_UNIX, SOCK_DGRAM, PF_UNIX}, }; - for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { ASSERT_NO_ERRNO_AND_VALUE( Socket(tests[i].domain, tests[i].type, tests[i].protocol)); } @@ -59,7 +59,7 @@ TEST(SocketTest, ProtocolInet) { {AF_INET, SOCK_DGRAM, IPPROTO_UDP}, {AF_INET, SOCK_STREAM, IPPROTO_TCP}, }; - for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { + for (long unsigned int i = 0; i < ABSL_ARRAYSIZE(tests); i++) { ASSERT_NO_ERRNO_AND_VALUE( Socket(tests[i].domain, tests[i].type, tests[i].protocol)); } @@ -87,7 +87,7 @@ TEST(SocketTest, UnixSocketStat) { ASSERT_THAT(stat(addr.sun_path, &statbuf), SyscallSucceeds()); // Mode should be S_IFSOCK. - EXPECT_EQ(statbuf.st_mode, S_IFSOCK | sock_perm & ~mask); + EXPECT_EQ(statbuf.st_mode, S_IFSOCK | (sock_perm & ~mask)); // Timestamps should be equal and non-zero. // TODO(b/158882152): Sockets currently don't implement timestamps. 
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 5ed57625c..06419772f 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -168,7 +168,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { std::vector<std::unique_ptr<ScopedThread>> listen_threads( listener_fds.size()); - for (int i = 0; i < listener_fds.size(); i++) { + for (long unsigned int i = 0; i < listener_fds.size(); i++) { listen_threads[i] = absl::make_unique<ScopedThread>( [&listener_fds, &accept_counts, &connects_received, i, kConnectAttempts]() { @@ -235,7 +235,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { listen_thread->Join(); } // Check that connections are distributed correctly among listening sockets. - for (int i = 0; i < accept_counts.size(); i++) { + for (long unsigned int i = 0; i < accept_counts.size(); i++) { EXPECT_THAT( accept_counts[i], EquivalentWithin(static_cast<int>(kConnectAttempts * @@ -308,7 +308,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { std::vector<std::unique_ptr<ScopedThread>> receiver_threads( listener_fds.size()); - for (int i = 0; i < listener_fds.size(); i++) { + for (long unsigned int i = 0; i < listener_fds.size(); i++) { receiver_threads[i] = absl::make_unique<ScopedThread>( [&listener_fds, &packets_per_socket, &packets_received, i]() { do { @@ -366,7 +366,7 @@ TEST_P(BindToDeviceDistributionTest, Udp) { receiver_thread->Join(); } // Check that packets are distributed correctly among listening sockets. 
- for (int i = 0; i < packets_per_socket.size(); i++) { + for (long unsigned int i = 0; i < packets_per_socket.size(); i++) { EXPECT_THAT( packets_per_socket[i], EquivalentWithin(static_cast<int>(kConnectAttempts * diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index 70cc86b16..a28ee2233 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -853,5 +853,21 @@ TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) { } } +TEST_P(AllSocketPairTest, GetSocketOutOfBandInlineOption) { + // We do not support disabling this option. It is always enabled. + SKIP_IF(!IsRunningOnGvisor()); + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int enable = -1; + socklen_t enableLen = sizeof(enable); + + int want = 1; + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_OOBINLINE, &enable, + &enableLen), + SyscallSucceeds()); + ASSERT_EQ(enableLen, sizeof(enable)); + EXPECT_EQ(enable, want); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index f69f8f99f..1694e188a 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -15,6 +15,9 @@ #include "test/syscalls/linux/socket_ip_udp_generic.h" #include <errno.h> +#ifdef __linux__ +#include <linux/in6.h> +#endif // __linux__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -356,6 +359,58 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) { EXPECT_EQ(get_len, sizeof(get)); } +// Test getsockopt for a socket which is not set with IP_RECVORIGDSTADDR option. 
+TEST_P(UDPSocketPairTest, ReceiveOrigDstAddrDefault) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int get = -1; + socklen_t get_len = sizeof(get); + int level = SOL_IP; + int type = IP_RECVORIGDSTADDR; + if (sockets->first_addr()->sa_family == AF_INET6) { + level = SOL_IPV6; + type = IPV6_RECVORIGDSTADDR; + } + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kSockOptOff); +} + +// Test setsockopt and getsockopt for a socket with IP_RECVORIGDSTADDR option. +TEST_P(UDPSocketPairTest, SetAndGetReceiveOrigDstAddr) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int level = SOL_IP; + int type = IP_RECVORIGDSTADDR; + if (sockets->first_addr()->sa_family == AF_INET6) { + level = SOL_IPV6; + type = IPV6_RECVORIGDSTADDR; + } + + // Check getsockopt before IP_PKTINFO is set. + int get = -1; + socklen_t get_len = sizeof(get); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOn); + EXPECT_EQ(get_len, sizeof(get)); + + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff, + sizeof(kSockOptOff)), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get, kSockOptOff); + EXPECT_EQ(get_len, sizeof(get)); +} + // Holds TOS or TClass information for IPv4 or IPv6 respectively. struct RecvTosOption { int level; @@ -438,7 +493,7 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) { // This should only test AF_INET6 sockets for the mismatch behavior. SKIP_IF(GetParam().domain != AF_INET6); // IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW. 
- SKIP_IF(GetParam().type != SOCK_DGRAM | GetParam().type != SOCK_RAW); + SKIP_IF((GetParam().type != SOCK_DGRAM) | (GetParam().type != SOCK_RAW)); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index b3f54e7f6..e557572a7 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2222,6 +2222,90 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Test that socket will receive IP_RECVORIGDSTADDR control message. +TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_addr = V4Loopback(); + int level = SOL_IP; + int type = IP_RECVORIGDSTADDR; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Retrieve the port bound by the receiver. + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + ASSERT_THAT( + connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Get address and port bound by the sender. 
+ sockaddr_storage sender_addr_storage; + socklen_t sender_addr_len = sizeof(sender_addr_storage); + ASSERT_THAT(getsockname(sender->get(), + reinterpret_cast<sockaddr*>(&sender_addr_storage), + &sender_addr_len), + SyscallSucceeds()); + ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in)); + + // Enable IP_RECVORIGDSTADDR on socket so that we get the original destination + // address of the datagram as auxiliary information in the control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in))] = {}; + size_t cmsg_data_len = sizeof(sockaddr_in); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/), + IsPosixErrorOkAndHolds(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Check the data + sockaddr_in received_addr = {}; + memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr)); + auto orig_receiver_addr = reinterpret_cast<sockaddr_in*>(&receiver_addr.addr); + 
EXPECT_EQ(received_addr.sin_addr.s_addr, orig_receiver_addr->sin_addr.s_addr); + EXPECT_EQ(received_addr.sin_port, orig_receiver_addr->sin_port); +} + // Check that setting SO_RCVBUF below min is clamped to the minimum // receive buffer size. TEST_P(IPv4UDPUnboundSocketTest, SetSocketRecvBufBelowMin) { diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc index 875016812..9a9ddc297 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -177,7 +177,7 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { // Broadcasts from each socket should be received by every socket (including // the sending socket). - for (int w = 0; w < socks.size(); w++) { + for (long unsigned int w = 0; w < socks.size(); w++) { auto& w_sock = socks[w]; ASSERT_THAT( RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, @@ -187,7 +187,7 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { << "write socks[" << w << "]"; // Check that we received the packet on all sockets. - for (int r = 0; r < socks.size(); r++) { + for (long unsigned int r = 0; r < socks.size(); r++) { auto& r_sock = socks[r]; struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.cc b/test/syscalls/linux/socket_ipv6_udp_unbound.cc new file mode 100644 index 000000000..08526468e --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound.cc @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv6_udp_unbound.h" + +#include <arpa/inet.h> +#include <netinet/in.h> +#ifdef __linux__ +#include <linux/in6.h> +#endif // __linux__ +#include <net/if.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/posix_error.h" +#include "test/util/save_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test that socket will receive IP_RECVORIGDSTADDR control message. +TEST_P(IPv6UDPUnboundSocketTest, SetAndReceiveIPReceiveOrigDstAddr) { + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_addr = V6Loopback(); + int level = SOL_IPV6; + int type = IPV6_RECVORIGDSTADDR; + + ASSERT_THAT( + bind(receiver->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Retrieve the port bound by the receiver. 
+ socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + ASSERT_THAT( + connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + + // Get address and port bound by the sender. + sockaddr_storage sender_addr_storage; + socklen_t sender_addr_len = sizeof(sender_addr_storage); + ASSERT_THAT(getsockname(sender->get(), + reinterpret_cast<sockaddr*>(&sender_addr_storage), + &sender_addr_len), + SyscallSucceeds()); + ASSERT_EQ(sender_addr_len, sizeof(struct sockaddr_in6)); + + // Enable IP_RECVORIGDSTADDR on socket so that we get the original destination + // address of the datagram as auxiliary information in the control message. + ASSERT_THAT( + setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Prepare message to send. 
+ constexpr size_t kDataLength = 1024; + msghdr sent_msg = {}; + iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = sent_data; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + sent_msg.msg_flags = 0; + + ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + msghdr received_msg = {}; + iovec received_iov = {}; + char received_data[kDataLength]; + char received_cmsg_buf[CMSG_SPACE(sizeof(sockaddr_in6))] = {}; + size_t cmsg_data_len = sizeof(sockaddr_in6); + received_iov.iov_base = received_data; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + received_msg.msg_control = received_cmsg_buf; + + ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/), + IsPosixErrorOkAndHolds(kDataLength)); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, level); + EXPECT_EQ(cmsg->cmsg_type, type); + + // Check that the received address in the control message matches the expected + // receiver's address. + sockaddr_in6 received_addr = {}; + memcpy(&received_addr, CMSG_DATA(cmsg), sizeof(received_addr)); + auto orig_receiver_addr = + reinterpret_cast<sockaddr_in6*>(&receiver_addr.addr); + EXPECT_EQ(memcmp(&received_addr.sin6_addr, &orig_receiver_addr->sin6_addr, + sizeof(in6_addr)), + 0); + EXPECT_EQ(received_addr.sin6_port, orig_receiver_addr->sin6_port); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound.h b/test/syscalls/linux/socket_ipv6_udp_unbound.h new file mode 100644 index 000000000..71e160f73 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv6 UDP sockets. +using IPv6UDPUnboundSocketTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc new file mode 100644 index 000000000..058336ecc --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv6_udp_unbound.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv6UDPSockets, IPv6UDPUnboundSocketTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_unbound_filesystem.cc b/test/syscalls/linux/socket_unix_unbound_filesystem.cc index cab912152..a035fb095 100644 --- a/test/syscalls/linux/socket_unix_unbound_filesystem.cc +++ b/test/syscalls/linux/socket_unix_unbound_filesystem.cc @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <fcntl.h> #include <stdio.h> #include <sys/un.h> #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" namespace gvisor { @@ -70,6 +72,20 @@ TEST_P(UnboundFilesystemUnixSocketPairTest, GetSockNameLength) { strlen(want_addr.sun_path) + 1 + sizeof(want_addr.sun_family)); } +TEST_P(UnboundFilesystemUnixSocketPairTest, OpenSocketWithTruncate) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + const struct sockaddr_un *addr = + reinterpret_cast<const struct sockaddr_un *>(sockets->first_addr()); + EXPECT_THAT(chmod(addr->sun_path, 0777), SyscallSucceeds()); + EXPECT_THAT(open(addr->sun_path, O_RDONLY | O_TRUNC), + SyscallFailsWithErrno(ENXIO)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnboundFilesystemUnixSocketPairTest, ::testing::ValuesIn(ApplyVec<SocketPairKind>( 
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc index 97d554e72..5ac337d1a 100644 --- a/test/syscalls/linux/tuntap.cc +++ b/test/syscalls/linux/tuntap.cc @@ -324,8 +324,9 @@ TEST_F(TuntapTest, PingKernel) { }; while (1) { inpkt r = {}; - int n = read(fd.get(), &r, sizeof(r)); - EXPECT_THAT(n, SyscallSucceeds()); + int nread = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(nread, SyscallSucceeds()); + long unsigned int n = static_cast<long unsigned int>(nread); if (n < sizeof(pihdr)) { std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol @@ -383,8 +384,9 @@ TEST_F(TuntapTest, SendUdpTriggersArpResolution) { }; while (1) { inpkt r = {}; - int n = read(fd.get(), &r, sizeof(r)); - EXPECT_THAT(n, SyscallSucceeds()); + int nread = read(fd.get(), &r, sizeof(r)); + EXPECT_THAT(nread, SyscallSucceeds()); + long unsigned int n = static_cast<long unsigned int>(nread); if (n < sizeof(pihdr)) { std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 90ef8bf21..21727a2e7 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -14,6 +14,8 @@ #include <arpa/inet.h> #include <fcntl.h> +#include <netinet/icmp6.h> +#include <netinet/ip_icmp.h> #include <ctime> @@ -779,6 +781,94 @@ TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) { SyscallFailsWithErrno(ECONNREFUSED)); } +#ifdef __linux__ +TEST_P(UdpSocketTest, RecvErrorConnRefused) { + // We will simulate an ICMP error and verify that we do receive that error via + // recvmsg(MSG_ERRQUEUE). + ASSERT_NO_ERRNO(BindLoopback()); + // Close the socket to release the port so that we get an ICMP error. + ASSERT_THAT(close(bind_.release()), SyscallSucceeds()); + + // Set IP_RECVERR socket option to enable error queueing. 
+ int v = kSockOptOn; + socklen_t optlen = sizeof(v); + int opt_level = SOL_IP; + int opt_type = IP_RECVERR; + if (GetParam() != AddressFamily::kIpv4) { + opt_level = SOL_IPV6; + opt_type = IPV6_RECVERR; + } + ASSERT_THAT(setsockopt(sock_.get(), opt_level, opt_type, &v, optlen), + SyscallSucceeds()); + + // Connect to loopback:bind_addr_ which should *hopefully* not be bound by a + // UDP socket. There is no easy way to ensure that the UDP port is not bound + // by another concurrently running test. *This is potentially flaky*. + const int kBufLen = 300; + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + char buf[kBufLen]; + RandomizeBuffer(buf, sizeof(buf)); + // Send from sock_ to an unbound port. This should cause ECONNREFUSED. + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + // Dequeue error using recvmsg(MSG_ERRQUEUE). + char got[kBufLen]; + struct iovec iov; + iov.iov_base = reinterpret_cast<void*>(got); + iov.iov_len = kBufLen; + + size_t control_buf_len = CMSG_SPACE(sizeof(sock_extended_err) + addrlen_); + char* control_buf = static_cast<char*>(calloc(1, control_buf_len)); + struct sockaddr_storage remote; + memset(&remote, 0, sizeof(remote)); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_flags = 0; + msg.msg_control = control_buf; + msg.msg_controllen = control_buf_len; + msg.msg_name = reinterpret_cast<void*>(&remote); + msg.msg_namelen = addrlen_; + ASSERT_THAT(recvmsg(sock_.get(), &msg, MSG_ERRQUEUE), + SyscallSucceedsWithValue(kBufLen)); + + // Check the contents of msg. + EXPECT_EQ(memcmp(got, buf, sizeof(buf)), 0); // iovec check + EXPECT_NE(msg.msg_flags & MSG_ERRQUEUE, 0); + EXPECT_EQ(memcmp(&remote, bind_addr_, addrlen_), 0); + + // Check the contents of the control message. 
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(CMSG_NXTHDR(&msg, cmsg), nullptr); + EXPECT_EQ(cmsg->cmsg_level, opt_level); + EXPECT_EQ(cmsg->cmsg_type, opt_type); + + // Check the contents of socket error. + struct sock_extended_err* sock_err = + (struct sock_extended_err*)CMSG_DATA(cmsg); + EXPECT_EQ(sock_err->ee_errno, ECONNREFUSED); + if (GetParam() == AddressFamily::kIpv4) { + EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP); + EXPECT_EQ(sock_err->ee_type, ICMP_DEST_UNREACH); + EXPECT_EQ(sock_err->ee_code, ICMP_PORT_UNREACH); + } else { + EXPECT_EQ(sock_err->ee_origin, SO_EE_ORIGIN_ICMP6); + EXPECT_EQ(sock_err->ee_type, ICMP6_DST_UNREACH); + EXPECT_EQ(sock_err->ee_code, ICMP6_DST_UNREACH_NOPORT); + } + + // Now verify that the socket error was cleared by recvmsg(MSG_ERRQUEUE). + int err; + optlen = sizeof(err); + ASSERT_THAT(getsockopt(sock_.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, 0); + ASSERT_EQ(optlen, sizeof(err)); +} +#endif // __linux__ + TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); diff --git a/tools/bazel.mk b/tools/bazel.mk index ca5621a9c..9b8def713 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -14,52 +14,76 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Make hacks. -EMPTY := -SPACE := $(EMPTY) $(EMPTY) +## +## Docker options. +## +## This file supports targets that wrap bazel in a running Docker +## container to simplify development. Some options are available to +## control the behavior of this container: +## +## USER - The in-container user. +## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests). +## DOCKER_NAME - The container name (default: gvisor-bazel-HASH). +## DOCKER_PRIVILEGED - Docker privileged flags (default: --privileged). 
+## BAZEL_CACHE - The bazel cache directory (default: detected). +## GCLOUD_CONFIG - The gcloud config directory (default: detected). +## DOCKER_SOCKET - The Docker socket (default: detected). +## +## To opt out of these wrappers, set DOCKER_BUILD=false. +DOCKER_BUILD := true +ifeq ($(DOCKER_BUILD),true) +-include bazel-server +endif # See base Makefile. -SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ - git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ - xargs -n 1 basename 2>/dev/null) + git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ + xargs -n 1 basename 2>/dev/null) BUILD_ROOTS := bazel-bin/ bazel-out/ # Bazel container configuration (see below). USER := $(shell whoami) HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) -BUILDER_BASE := gvisor.dev/images/default -BUILDER_IMAGE := gvisor.dev/images/builder -BUILDER_NAME := gvisor-builder-$(HASH) -DOCKER_NAME := gvisor-bazel-$(HASH) +BUILDER_NAME := gvisor-builder-$(HASH)-$(ARCH) +DOCKER_NAME := gvisor-bazel-$(HASH)-$(ARCH) DOCKER_PRIVILEGED := --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock DOCKER_CONFIG := /etc/docker -# Bazel flags. -BAZEL := bazel $(STARTUP_OPTIONS) -BASE_OPTIONS := --color=no --curses=no -ifneq (,$(BAZEL_CONFIG)) -BASE_OPTIONS += --config=$(BAZEL_CONFIG) -endif +## +## Bazel helpers. +## +## Bazel will be run with standard flags. You can specify the following flags +## to control which flags are passed: +## +## STARTUP_OPTIONS - Startup options passed to Bazel. +## +STARTUP_OPTIONS := +BAZEL := bazel $(STARTUP_OPTIONS) +BASE_OPTIONS := --color=no --curses=no +TEST_OPTIONS := $(BASE_OPTIONS) \ + --test_output=errors \ + --keep_going \ + --verbose_failures=true \ + --build_event_json_file=.build_events.json # Basic options. 
UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) USERADD_OPTIONS := -FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) -FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID) -FULL_DOCKER_RUN_OPTIONS += --entrypoint "" -FULL_DOCKER_RUN_OPTIONS += --init -FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" -FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" -FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" -FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) -FULL_DOCKER_EXEC_OPTIONS += --interactive -ifeq (true,$(shell [[ -t 0 ]] && echo true)) -FULL_DOCKER_EXEC_OPTIONS += --tty +DOCKER_RUN_OPTIONS := +DOCKER_RUN_OPTIONS += --user $(UID):$(GID) +DOCKER_RUN_OPTIONS += --entrypoint "" +DOCKER_RUN_OPTIONS += --init +DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" +DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" +DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" +DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) +DOCKER_EXEC_OPTIONS += --interactive +ifeq (true,$(shell test -t 0 && echo true)) +DOCKER_EXEC_OPTIONS += --tty endif # Add basic UID/GID options. @@ -83,86 +107,75 @@ endif # Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) -FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" -FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)" -FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) -FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) +DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" +DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)" +DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) +DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) ifneq ($(GID),$(DOCKER_GROUP)) USERADD_OPTIONS += --groups $(DOCKER_GROUP) GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) && -FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) +DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) endif endif # Add KVM passthrough options. 
ifneq (,$(wildcard /dev/kvm)) -FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm +DOCKER_RUN_OPTIONS += --device=/dev/kvm KVM_GROUP := $(shell stat -c '%g' /dev/kvm) ifneq ($(GID),$(KVM_GROUP)) USERADD_OPTIONS += --groups $(KVM_GROUP) GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) && -FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP) +DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP) endif endif -bazel-image: load-default - @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi - docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ - $(BUILDER_BASE) \ - sh -c "$(GROUPADD_DOCKER) \ - $(USERADD_DOCKER) \ - if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" - docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) - @docker rm -f $(BUILDER_NAME) -.PHONY: bazel-image - -## -## Bazel helpers. -## -## This file supports targets that wrap bazel in a running Docker -## container to simplify development. Some options are available to -## control the behavior of this container: -## USER - The in-container user. -## DOCKER_RUN_OPTIONS - Options for the container (default: --privileged, required for tests). -## DOCKER_NAME - The container name (default: gvisor-bazel-HASH). -## BAZEL_CACHE - The bazel cache directory (default: detected). -## GCLOUD_CONFIG - The gcloud config directory (detect: detected). -## DOCKER_SOCKET - The Docker socket (default: detected). -## -bazel-server-start: bazel-image ## Starts the bazel server. - @mkdir -p $(BAZEL_CACHE) - @mkdir -p $(GCLOUD_CONFIG) - @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi - # This command runs a bazel server, and the container sticks around - # until the bazel server exits. This should ensure that it does not - # exit in the middle of running a build, but also it won't stick around - # forever. The build commands wrap around an appropriate exec into the - # container in order to perform work via the bazel client. 
- docker run -d --rm --name $(DOCKER_NAME) \ - -v "$(CURDIR):$(CURDIR)" \ - --workdir "$(CURDIR)" \ - $(FULL_DOCKER_RUN_OPTIONS) \ - $(BUILDER_IMAGE) \ - sh -c "tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null" -.PHONY: bazel-server-start +# Top-level functions. +# +# This command runs a bazel server, and the container sticks around +# until the bazel server exits. This should ensure that it does not +# exit in the middle of running a build, but also it won't stick around +# forever. The build commands wrap around an appropriate exec into the +# container in order to perform work via the bazel client. +ifeq ($(DOCKER_BUILD),true) +wrapper = docker exec $(DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(1) +else +wrapper = $(1) +endif bazel-shutdown: ## Shuts down a running bazel server. - @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \ - rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] + @$(call wrapper,$(BAZEL) shutdown) .PHONY: bazel-shutdown bazel-alias: ## Emits an alias that can be used within the shell. - @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'" + @echo "alias bazel='$(call wrapper,$(BAZEL))'" .PHONY: bazel-alias -bazel-server: ## Ensures that the server exists. Used as an internal target. - @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true >&2 || $(MAKE) bazel-server-start >&2 -.PHONY: bazel-server +bazel-image: load-default ## Ensures that the local builder exists. + @$(call header,DOCKER BUILD) + @docker rm -f $(BUILDER_NAME) 2>/dev/null || true + @docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) gvisor.dev/images/default \ + sh -c "$(GROUPADD_DOCKER) $(USERADD_DOCKER) if test -e /dev/kvm; then chmod a+rw /dev/kvm; fi" >&2 + @docker commit $(BUILDER_NAME) gvisor.dev/images/builder >&2 +.PHONY: bazel-image -# build_cmd builds the given targets in the bazel-server container. 
-build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c \ - '$(BAZEL) build $(BASE_OPTIONS) $(OPTIONS) "$(TARGETS)"' +ifneq (true,$(shell $(wrapper echo true))) +bazel-server: bazel-image ## Ensures that the server exists. + @$(call header,DOCKER RUN) + @docker rm -f $(DOCKER_NAME) 2>/dev/null || true + @mkdir -p $(GCLOUD_CONFIG) + @mkdir -p $(BAZEL_CACHE) + @docker run -d --rm --name $(DOCKER_NAME) \ + -v "$(CURDIR):$(CURDIR)" \ + --workdir "$(CURDIR)" \ + $(DOCKER_RUN_OPTIONS) \ + gvisor.dev/images/builder \ + sh -c "set -x; tail -f --pid=\$$($(BAZEL) info server_pid) /dev/null" >&2 +else +bazel-server: + @ +endif +.PHONY: bazel-server # build_paths extracts the built binary from the bazel stderr output. # @@ -172,49 +185,33 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai # command here? Yikes, let's just stick with the ugly shell pipeline. # # The last line is used to prevent terminal shenanigans. -build_paths = command_line=$$( $(build_cmd) 2>&1 \ - | grep -A1 -E '^Target' \ - | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \ - | sed "s/ /\n/g" \ - | strings -n 10 \ - | awk '{$$1=$$1};1' \ - | xargs -n 1 -I {} readlink -f "{}" \ - | xargs -n 1 -I {} echo "$(1)" ) && \ - (set -xeuo pipefail; eval $${command_line}) - -build: bazel-server - @$(call build_cmd) -.PHONY: build - -copy: bazel-server -ifeq (,$(DESTINATION)) - $(error Destination not provided.) -endif - @$(call build_paths,cp -fa {} $(DESTINATION)) - -run: bazel-server - @$(call build_paths,{} $(ARGS)) -.PHONY: run - -sudo: bazel-server - @$(call build_paths,sudo -E {} $(ARGS)) -.PHONY: sudo - -test: bazel-server - @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) \ - $(BAZEL) test $(BASE_OPTIONS) \ - --test_output=errors --keep_going --verbose_failures=true \ - --build_event_json_file=.build_events.json \ - $(OPTIONS) $(TARGETS) -.PHONY: test - -testlogs: - @cat .build_events.json | jq -r \ - 'select(.testSummary?.overallStatus? 
| tostring | test("(FAILED|FLAKY|TIMEOUT)")) | .testSummary.failed | .[] | .uri' | \ - awk -Ffile:// '{print $$2;}' +build_paths = \ + (set -euo pipefail; \ + $(call wrapper,$(BAZEL) build $(BASE_OPTIONS) $(1)) 2>&1 \ + | tee /proc/self/fd/2 \ + | grep -A1 -E '^Target' \ + | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \ + | sed "s/ /\n/g" \ + | strings -n 10 \ + | awk '{$$1=$$1};1' \ + | xargs -n 1 -I {} readlink -f "{}" \ + | xargs -n 1 -I {} bash -c 'set -xeuo pipefail; $(2)') + +clean = $(call header,CLEAN) && $(call wrapper,$(BAZEL) clean) +build = $(call header,BUILD $(1)) && $(call build_paths,$(1),echo {}) +copy = $(call header,COPY $(1) $(2)) && $(call build_paths,$(1),cp -fa {} $(2)) +run = $(call header,RUN $(1) $(2)) && $(call build_paths,$(1),{} $(2)) +sudo = $(call header,SUDO $(1) $(2)) && $(call build_paths,$(1),sudo -E {} $(2)) +test = $(call header,TEST $(1)) && $(call wrapper,$(BAZEL) test $(TEST_OPTIONS) $(1)) + +clean: ## Cleans the bazel cache. + @$(call clean) +.PHONY: clean + +testlogs: ## Returns the most recent set of test logs. + @if test -f .build_events.json; then \ + cat .build_events.json | jq -r \ + 'select(.testSummary?.overallStatus? 
| tostring | test("(FAILED|FLAKY|TIMEOUT)")) | "\(.id.testSummary.label) \(.testSummary.failed[].uri)"' | \ + sed -e 's|file://||'; \ + fi .PHONY: testlogs - -query: bazel-server - @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c \ - '$(BAZEL) query $(BASE_OPTIONS) $(OPTIONS) "$(TARGETS)" 2>/dev/null' -.PHONY: query diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index 27e85a75e..a4a605346 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -1,46 +1,7 @@ -load("//tools:defs.bzl", "bzl_library", "rbe_platform", "rbe_toolchain") +load("//tools:defs.bzl", "bzl_library") package(licenses = ["notice"]) -# We need to define a bazel platform and toolchain to specify dockerPrivileged -# and dockerRunAsRoot options, they are required to run tests on the RBE -# cluster in Kokoro. -rbe_platform( - name = "rbe_ubuntu1604", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains//constraints:xenial", - "@bazel_toolchains//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -rbe_toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [], - tags = [ - "manual", - ], - target_compatible_with = [], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) - bzl_library( name = "platforms_bzl", srcs = ["platforms.bzl"], @@ -58,3 +19,21 @@ bzl_library( srcs = ["defs.bzl"], visibility = ["//visibility:private"], ) + +config_setting( + name = 
"linux_arm64_cross", + values = { + "cpu": "aarch64", + "host_cpu": "k8", + }, + visibility = ["//visibility:private"], +) + +config_setting( + name = "linux_amd64_cross", + values = { + "cpu": "k8", + "host_cpu": "aarch64", + }, + visibility = ["//visibility:private"], +) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 51e17a79a..58ced5167 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -5,8 +5,6 @@ load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library") build_test = _build_test bzl_library = _bzl_library -rbe_platform = native.platform -rbe_toolchain = native.toolchain more_shards = 4 most_shards = 8 @@ -39,3 +37,44 @@ def default_net_util(): def coreutil(): return [] # Nothing needed. + +def select_native_vs_cross(native = [], amd64 = [], arm64 = [], cross = []): + values = { + "//tools/bazeldefs:linux_arm64_cross": arm64 + cross, + "//tools/bazeldefs:linux_amd64_cross": amd64 + cross, + "//conditions:default": native, + } + return select(values) + +def arch_genrule(name, srcs, outs, cmd, tools): + """Runs a gen command on the target architecture. + + If the target architecture isn't match the host architecture, it will build + a command for the target architecture and run it via qemu. + + The native genrule runs the command on the host architecture. + + Args: + name: name of generated target. + srcs: A list of inputs for this rule. + cmd: The command to run. It has to contain " QEMU " before executed binaries. + outs: A list of files generated by this rule. + tools: A list of tool dependencies for this rule. 
+ """ + qemu_arm64 = "qemu-aarch64-static" + qemu_amd64 = "qemu-x86_64-static" + srcs = select_native_vs_cross( + cross = srcs + tools, + native = srcs, + ) + tools = select_native_vs_cross( + cross = [], + native = tools, + ) + cmd = select_native_vs_cross( + arm64 = cmd.replace("QEMU", qemu_arm64), + amd64 = cmd.replace("QEMU", qemu_amd64), + native = cmd.replace("QEMU", ""), + cross = "", + ) + native.genrule(name = name, srcs = srcs, outs = outs, cmd = cmd, tools = tools) diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl index 661c9727e..bcd8cffe7 100644 --- a/tools/bazeldefs/go.bzl +++ b/tools/bazeldefs/go.bzl @@ -28,7 +28,7 @@ def go_proto_library(name, **kwargs): def go_grpc_and_proto_libraries(name, **kwargs): _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs) -def go_binary(name, static = False, pure = False, x_defs = None, **kwargs): +def go_binary(name, static = False, pure = False, x_defs = None, system_malloc = False, **kwargs): """Build a go binary. Args: @@ -52,7 +52,7 @@ def go_importpath(target): """Returns the importpath for the target.""" return target[GoLibrary].importpath -def go_library(name, **kwargs): +def go_library(name, arch_deps = [], **kwargs): _go_library( name = name, importpath = "gvisor.dev/gvisor/" + native.package_name(), diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 544af3876..a4ca93ec2 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -21,6 +21,7 @@ package bigquery import ( "context" "fmt" + "strconv" "strings" "time" @@ -109,6 +110,12 @@ func NewBenchmark(name string, iters int) *Benchmark { return &Benchmark{ Name: name, Metric: make([]*Metric, 0), + Condition: []*Condition{ + { + Name: "iterations", + Value: strconv.Itoa(iters), + }, + }, } } diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go index 27991649f..f46eba39b 100644 --- a/tools/checkescape/test1/test1.go +++ b/tools/checkescape/test1/test1.go @@ -36,17 +36,20 @@ 
func (t Type) Foo() { fmt.Printf("%v", t) // Never executed. } +// InterfaceFunction is passed an interface argument. // +checkescape:all,hard //go:nosplit func InterfaceFunction(i Interface) { // Do nothing; exported for tests. } +// TypeFunction is passed a concrete pointer argument. // +checkesacape:all,hard //go:nosplit func TypeFunction(t *Type) { } +// BuiltinMap creates a new map. // +mustescape:local,builtin //go:noinline //go:nosplit @@ -61,7 +64,8 @@ func builtinMapRec(x int) map[string]bool { return BuiltinMap(x) } -// +temustescapestescape:local,builtin +// BuiltinClosure returns a closure around x. +// +mustescape:local,builtin //go:noinline //go:nosplit func BuiltinClosure(x int) func() { @@ -77,6 +81,7 @@ func builtinClosureRec(x int) func() { return BuiltinClosure(x) } +// BuiltinMakeSlice makes a new slice. // +mustescape:local,builtin //go:noinline //go:nosplit @@ -91,6 +96,7 @@ func builtinMakeSliceRec(x int) []byte { return BuiltinMakeSlice(x) } +// BuiltinAppend calls append on a slice. // +mustescape:local,builtin //go:noinline //go:nosplit @@ -105,6 +111,7 @@ func builtinAppendRec() []byte { return BuiltinAppend(nil) } +// BuiltinChan makes a channel. // +mustescape:local,builtin //go:noinline //go:nosplit @@ -119,6 +126,7 @@ func builtinChanRec() chan int { return BuiltinChan() } +// Heap performs an explicit heap allocation. // +mustescape:local,heap //go:noinline //go:nosplit @@ -134,6 +142,7 @@ func heapRec() *Type { return Heap() } +// Dispatch dispatches via an interface. // +mustescape:local,interface //go:noinline //go:nosplit @@ -148,6 +157,7 @@ func dispatchRec(i Interface) { Dispatch(i) } +// Dynamic invokes a dynamic function. // +mustescape:local,dynamic //go:noinline //go:nosplit @@ -167,6 +177,7 @@ func dynamicRec(f func()) { func internalFunc() { } +// Split includes a guaranteed stack split. 
// +mustescape:local,stack //go:noinline func Split() { diff --git a/tools/defs.bzl b/tools/defs.bzl index b6f188aeb..56c481f44 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -8,7 +8,7 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") load("//tools/nogo:defs.bzl", "nogo_test") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path") +load("//tools/bazeldefs:defs.bzl", _arch_genrule = "arch_genrule", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _more_shards = "more_shards", _most_shards = "most_shards", _proto_library = "proto_library", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path") load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _select_goarch = "select_goarch", 
_select_goos = "select_goos") load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") @@ -16,6 +16,7 @@ load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", load("//tools/bazeldefs:tags.bzl", "go_suffixes") # Core rules. +arch_genrule = _arch_genrule build_test = _build_test bzl_library = _bzl_library default_installer = _default_installer @@ -23,8 +24,6 @@ default_net_util = _default_net_util select_arch = _select_arch select_system = _select_system short_path = _short_path -rbe_platform = _rbe_platform -rbe_toolchain = _rbe_toolchain coreutil = _coreutil more_shards = _more_shards most_shards = _most_shards @@ -184,6 +183,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F name + suffix + "_state_autogen.go" for suffix in state_sets.keys() ] + if "//pkg/state" not in all_deps: all_deps = all_deps + ["//pkg/state"] diff --git a/tools/go_branch.sh b/tools/go_branch.sh index 768a37b9a..7ef4ddf83 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -89,8 +89,14 @@ git merge --no-commit --strategy ours "${head}" || \ find . -type f -exec chmod 0644 {} \; find . -type d -exec chmod 0755 {} \; -# Sync the entire gopath_dir. -rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" . +# Sync the entire gopath_dir. Note that we exclude auto-generated source +# files that will change here. Otherwise, it adds a tremendous amount of noise +# to commits. If this file disappears in the future, then presumably we will +# still delete the underlying directory. +rsync --recursive --delete \ + --exclude .git \ + --exclude webhook/pkg/injector/certs.go \ + -L "${gopath_dir}/" . # Add additional files. 
for file in "${othersrc[@]}"; do diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index ad97208a8..50e2546bf 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -67,7 +67,7 @@ def _go_template_instance_impl(ctx): # Check that all defined types are expected by the template. for t in ctx.attr.types: if (t not in info.types) and (t not in info.opt_types): - fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) + fail("Type %s is not a parameter to %s" % (t, ctx.attr.template.label)) # Check that all required consts are defined. for t in info.consts: @@ -77,7 +77,7 @@ def _go_template_instance_impl(ctx): # Check that all defined consts are expected by the template. for t in ctx.attr.consts: if (t not in info.consts) and (t not in info.opt_consts): - fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) + fail("Const %s is not a parameter to %s" % (t, ctx.attr.template.label)) # Build the argument list. args = ["-i=%s" % info.template.path, "-o=%s" % output.path] diff --git a/tools/go_generics/generics.go b/tools/go_generics/generics.go index 0860ca9db..30584006c 100644 --- a/tools/go_generics/generics.go +++ b/tools/go_generics/generics.go @@ -223,7 +223,7 @@ func main() { } else { switch kind { case globals.KindType, globals.KindVar, globals.KindConst, globals.KindFunction: - if ident.Name != "_" { + if ident.Name != "_" && !(ident.Name == "init" && kind == globals.KindFunction) { ident.Name = *prefix + ident.Name + *suffix } case globals.KindTag: diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 4a53d25be..28ae6c4ef 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -213,10 +213,11 @@ type sliceAPI struct { type marshallableType struct { spec *ast.TypeSpec slice *sliceAPI + recv string } -func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) 
marshallableType { - mt := marshallableType{ +func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.TypeSpec) *marshallableType { + mt := &marshallableType{ spec: spec, slice: nil, } @@ -261,12 +262,31 @@ func newMarshallableType(fset *token.FileSet, tagLine *ast.Comment, spec *ast.Ty // collectMarshallableTypes walks the parsed AST and collects a list of type // declarations for which we need to generate the Marshallable interface. -func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []marshallableType { - var types []marshallableType +func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) map[*ast.TypeSpec]*marshallableType { + recv := make(map[string]string) // Type name to recevier name. + types := make(map[*ast.TypeSpec]*marshallableType) for _, decl := range a.Decls { gdecl, ok := decl.(*ast.GenDecl) // Type declaration? if !ok || gdecl.Tok != token.TYPE { + // Is this a function declaration? We remember receiver names. + d, ok := decl.(*ast.FuncDecl) + if ok && d.Recv != nil && len(d.Recv.List) == 1 { + // Accept concrete methods & pointer methods. + ident, ok := d.Recv.List[0].Type.(*ast.Ident) + if !ok { + var st *ast.StarExpr + st, ok = d.Recv.List[0].Type.(*ast.StarExpr) + if ok { + ident, ok = st.X.(*ast.Ident) + } + } + // The receiver name may be not present. + if ok && len(d.Recv.List[0].Names) == 1 { + // Recover the type receiver name in this case. + recv[ident.Name] = d.Recv.List[0].Names[0].Name + } + } debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n") continue } @@ -305,10 +325,20 @@ func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []ma // don't support it. 
abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name)) } - types = append(types, newMarshallableType(f, tagLine, t)) - + types[t] = newMarshallableType(f, tagLine, t) } } + // Update the types with the last seen receiver. As long as the + // receiver name is consistent for the type, then we will generate + // code that is still consistent with itself. + for t, mt := range types { + r, ok := recv[t.Name.Name] + if !ok { + mt.recv = receiverName(t) // Default. + continue + } + mt.recv = r // Last seen. + } return types } @@ -345,8 +375,8 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp } -func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interfaceGenerator { - i := newInterfaceGenerator(t.spec, fset) +func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *interfaceGenerator { + i := newInterfaceGenerator(t.spec, t.recv, fset) switch ty := t.spec.Type.(type) { case *ast.StructType: i.validateStruct(t.spec, ty) @@ -376,8 +406,8 @@ func (g *Generator) generateOne(t marshallableType, fset *token.FileSet) *interf // generateOneTestSuite generates a test suite for the automatically generated // implementations type t. -func (g *Generator) generateOneTestSuite(t marshallableType) *testGenerator { - i := newTestGenerator(t.spec) +func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { + i := newTestGenerator(t.spec, t.recv) i.emitTests(t.slice) return i } @@ -417,7 +447,15 @@ func (g *Generator) Run() error { for i, a := range asts { // Collect type declarations marked for code generation and generate // Marshallable interfaces. 
+ var sortedTypes []*marshallableType for _, t := range g.collectMarshallableTypes(a, fsets[i]) { + sortedTypes = append(sortedTypes, t) + } + sort.Slice(sortedTypes, func(x, y int) bool { + // Sort by type name, which should be unique within a package. + return sortedTypes[x].spec.Name.String() < sortedTypes[y].spec.Name.String() + }) + for _, t := range sortedTypes { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. for ref := range impl.ms { diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 36447b86b..65f5ea34d 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -54,10 +54,10 @@ func (g *interfaceGenerator) typeName() string { } // newinterfaceGenerator creates a new interface generator. -func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator { +func newInterfaceGenerator(t *ast.TypeSpec, r string, fset *token.FileSet) *interfaceGenerator { g := &interfaceGenerator{ t: t, - r: receiverName(t), + r: r, f: fset, is: make(map[string]struct{}), ms: make(map[string]struct{}), diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index 631295373..6cf00843f 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -53,10 +53,10 @@ type testGenerator struct { decl *importStmt } -func newTestGenerator(t *ast.TypeSpec) *testGenerator { +func newTestGenerator(t *ast.TypeSpec, r string) *testGenerator { g := &testGenerator{ t: t, - r: receiverName(t), + r: r, imports: newImportTable(), } diff --git a/tools/images.mk b/tools/images.mk new file mode 100644 index 000000000..2003da5bd --- /dev/null +++ b/tools/images.mk @@ -0,0 +1,169 @@ +#!/usr/bin/make -f + +# Copyright 2018 The gVisor Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +## +## Docker image targets. +## +## Images used by the tests must also be built and available locally. +## The canonical test targets defined below will automatically load +## relevant images. These can be loaded or built manually via these +## targets. +## +## (*) Note that you may provide an ARCH parameter in order to build +## and load images from an alternate archiecture (using qemu). When +## bazel is run as a server, this has the effect of running an full +## cross-architecture chain, and can produce cross-compiled binaries. +## + +# ARCH is the architecture used for the build. This may be overriden at the +# command line in order to perform a cross-build (in a limited capacity). +ARCH := $(shell uname -m) +ifneq ($(ARCH),$(shell uname -m)) +DOCKER_PLATFORM_ARGS := --platform=$(ARCH) +else +DOCKER_PLATFORM_ARGS := +endif + +# Note that the image prefixes used here must match the image mangling in +# runsc/testutil.MangleImage. Names are mangled in this way to ensure that all +# tests are using locally-defined images (that are consistent and idempotent). 
+REMOTE_IMAGE_PREFIX ?= gcr.io/gvisor-presubmit +LOCAL_IMAGE_PREFIX ?= gvisor.dev/images +ALL_IMAGES := $(subst /,_,$(subst images/,,$(shell find images/ -name Dockerfile -o -name Dockerfile.$(ARCH) | xargs -n 1 dirname | uniq))) +SUB_IMAGES := $(foreach image,$(ALL_IMAGES),$(if $(findstring _,$(image)),$(image),)) +IMAGE_GROUPS := $(sort $(foreach image,$(SUB_IMAGES),$(firstword $(subst _, ,$(image))))) + +define expand_group = +load-$(1): $$(patsubst $(1)_%, load-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES))) + @ +.PHONY: load-$(1) +push-$(1): $$(patsubst $(1)_%, push-$(1)_%, $$(filter $(1)_%,$$(ALL_IMAGES))) + @ +.PHONY: push-$(1) +endef +$(foreach group,$(IMAGE_GROUPS),$(eval $(call expand_group,$(group)))) + +list-all-images: ## List all images. + @for image in $(ALL_IMAGES); do echo $${image}; done +.PHONY: list-all-images + +load-all-images: ## Load all images. +load-all-images: $(patsubst %,load-%,$(ALL_IMAGES)) +.PHONY: load-all-images + +push-all-images: ## Push all images. +push-all-images: $(patsubst %,push-%,$(ALL_IMAGES)) +.PHONY: push-all-images + +# path and dockerfile are used to extract the relevant path and dockerfile +# (depending on what's available for the given architecture). +path = images/$(subst _,/,$(1)) +dockerfile = $$(if [ -f "$(call path,$(1))/Dockerfile.$(ARCH)" ]; then echo Dockerfile.$(ARCH); else echo Dockerfile; fi) + +# The tag construct is used to memoize the image generated (see README.md). +# This scheme is used to enable aggressive caching in a central repository, but +# ensuring that images will always be sourced using the local files. +tag = $(shell cd images && find $(subst _,/,$(1)) -type f | sort | xargs -n 1 sha256sum | sha256sum - | cut -c 1-16) +remote_image = $(REMOTE_IMAGE_PREFIX)/$(subst _,/,$(1))_$(ARCH) +local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) + +# Include all existing images as targets here. 
+# +# Note that we use a _ for the tag separator, instead of :, as the latter is +# interpreted by Make, unfortunately. tag_expand expands the generic rules to +# tag-specific targets. These is needed to provide sensible targets for load +# below, with caching. Basically, if there is a rule generated here, then the +# load will be skipped. If there is no load generated here, then the default +# rule for load will kick in. +# +# Note that if this rule does not successfully rule, we will simply have +# additional Docker pull commands that run for all images that are already +# pulled. No real harm done. +EXISTING_IMAGES = $(shell docker images --format '{{.Repository}}_{{.Tag}}' | grep -v '<none>') +define existing_image_rule = +loaded0_$(1)=load-$$(1): tag-$$(1) # Already available. +loaded1_$(1)=.PHONY: load-$$(1) +endef +$(foreach image, $(EXISTING_IMAGES), $(eval $(call existing_image_rule,$(image)))) +define tag_expand_rule = +$(eval $(loaded0_$(call remote_image,$(1))_$(call tag,$(1)))) +$(eval $(loaded1_$(call remote_image,$(1))_$(call tag,$(1)))) +endef +$(foreach image, $(ALL_IMAGES), $(eval $(call tag_expand_rule,$(image)))) + +# tag tags a local image. This applies both the hash-based tag from above to +# ensure that caching works as expected, as well as the "latest" tag that is +# used by the tests. +local_tag = \ + docker tag $(call remote_image,$(1)):$(call tag,$(1)) $(call local_image,$(1)):$(call tag,$(1)) >&2 +latest_tag = \ + docker tag $(call local_image,$(1)):$(call tag,$(1)) $(call local_image,$(1)) >&2 +tag-%: ## Tag a local image. + @$(call header,TAG $*) + @$(call local_tag,$*) && $(call latest_tag,$*) + +# pull forces the image to be pulled. +pull = \ + $(call header,PULL $(1)) && \ + docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$(1)):$(call tag,$(1)) >&2 && \ + $(call local_tag,$(1)) && \ + $(call latest_tag,$(1)) +pull-%: register-cross ## Force a repull of the image. + @$(call pull,$*) + +# rebuild builds the image locally. 
Only the "remote" tag will be applied. Note +# we need to explicitly repull the base layer in order to ensure that the +# architecture is correct. Note that we use the term "rebuild" here to avoid +# conflicting with the bazel "build" terminology, which is used elsewhere. +rebuild = \ + $(call header,REBUILD $(1)) && \ + (T=$$(mktemp -d) && cp -a $(call path,$(1))/* $$T && \ + $(foreach image,$(shell grep FROM "$(call path,$(1))/$(call dockerfile,$(1))" 2>/dev/null | cut -d' ' -f2),docker pull $(DOCKER_PLATFORM_ARGS) $(image) >&2 &&) \ + docker build $(DOCKER_PLATFORM_ARGS) \ + -f "$$T/$(call dockerfile,$(1))" \ + -t "$(call remote_image,$(1)):$(call tag,$(1))" \ + $$T >&2 && \ + rm -rf $$T) && \ + $(call local_tag,$(1)) && \ + $(call latest_tag,$(1)) +rebuild-%: register-cross ## Force rebuild an image locally. + @$(call rebuild,$*) + +# load will either pull the "remote" or build it locally. This is the preferred +# entrypoint, as it should never fail. The local tag should always be set after +# this returns (either by the pull or the build). +load-%: register-cross ## Pull or build an image locally. + @($(call pull,$*)) || ($(call rebuild,$*)) + +# push pushes the remote image, after either pulling (to validate that the tag +# already exists) or building manually. Note that this generic rule will match +# the fully-expanded remote image tag. +push-%: load-% ## Push a given image. + @docker push $(call remote_image,$*):$(call tag,$*) >&2 + +# register-cross registers the necessary qemu binaries for cross-compilation. +# This may be used by any target that may execute containers that are not the +# native format. Note that this will only apply on the first execution. 
+register-cross: +ifneq ($(ARCH),$(shell uname -m)) +ifeq (,$(wildcard /proc/sys/fs/binfmt_misc/qemu-*)) + @docker run --rm --privileged multiarch/qemu-user-static --reset --persistent yes >&2 +else + @ +endif +else + @ +endif diff --git a/tools/installers/BUILD b/tools/installers/BUILD index 13d3cc5e0..bbf3c1f85 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -1,4 +1,4 @@ -# Installers for use by the tools/vm_test rules. +# Installers for use by top-level scripts. package( default_visibility = ["//:sandbox"], @@ -14,14 +14,6 @@ sh_binary( ) sh_binary( - name = "images", - srcs = ["images.sh"], - data = [ - "//images", - ], -) - -sh_binary( name = "master", srcs = ["master.sh"], ) diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh index 5520a447c..d28549734 100755 --- a/tools/installers/containerd.sh +++ b/tools/installers/containerd.sh @@ -16,7 +16,7 @@ set -xeo pipefail -declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0} +declare -r CONTAINERD_VERSION=${1:-1.3.0} declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')" declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' 
'{ print $2; }')" diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 12b8b597c..566e0889e 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -3,6 +3,8 @@ load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib", "nogo_target") package(licenses = ["notice"]) +exports_files(["config-schema.json"]) + nogo_target( name = "target", goarch = select_goarch(), diff --git a/tools/nogo/config-schema.json b/tools/nogo/config-schema.json new file mode 100644 index 000000000..3c25fe221 --- /dev/null +++ b/tools/nogo/config-schema.json @@ -0,0 +1,97 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema", + "definitions": { + "group": { + "type": "object", + "properties": { + "name": { + "description": "The name of the group.", + "type": "string" + }, + "regex": { + "description": "A regular expression for matching paths.", + "type": "string" + }, + "default": { + "description": "Whether the group is enabled by default.", + "type": "boolean" + } + }, + "required": [ + "name", + "regex", + "default" + ], + "additionalProperties": false + }, + "regexlist": { + "description": "A list of regular expressions.", + "oneOf": [ + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "null" + } + ] + }, + "rule": { + "type": "object", + "properties": { + "exclude": { + "description": "A regular expression for paths to exclude.", + "$ref": "#/definitions/regexlist" + }, + "suppress": { + "description": "A regular expression for messages to suppress.", + "$ref": "#/definitions/regexlist" + } + }, + "additionalProperties": false + }, + "ruleList": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "$ref": "#/definitions/rule" + }, + { + "type": "null" + } + ] + } + } + }, + "properties": { + "groups": { + "description": "A definition of all groups.", + "type": "array", + "items": { + "$ref": "#/definitions/group" + }, + "minItems": 1 + }, + "global": { + "description": "A global set of rules.", + "type": "object", + 
"additionalProperties": { + "$ref": "#/definitions/rule" + } + }, + "analyzers": { + "description": "A definition of all groups.", + "type": "object", + "additionalProperties": { + "$ref": "#/definitions/ruleList" + } + } + }, + "required": [ + "groups" + ], + "additionalProperties": false +} diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go index 9cf41b3b0..8be38ca6d 100644 --- a/tools/nogo/filter/main.go +++ b/tools/nogo/filter/main.go @@ -16,6 +16,7 @@ package main import ( + "bytes" "flag" "fmt" "io/ioutil" @@ -76,12 +77,14 @@ func main() { log.Fatalf("unable to read %s: %v", filename, err) } var newConfig nogo.Config // For current file. - if err := yaml.Unmarshal(content, &newConfig); err != nil { + dec := yaml.NewDecoder(bytes.NewBuffer(content)) + dec.SetStrict(true) + if err := dec.Decode(&newConfig); err != nil { log.Fatalf("unable to decode %s: %v", filename, err) } config.Merge(&newConfig) if showConfig { - bytes, err := yaml.Marshal(&newConfig) + content, err := yaml.Marshal(&newConfig) if err != nil { log.Fatalf("error marshalling config: %v", err) } @@ -89,7 +92,7 @@ func main() { if err != nil { log.Fatalf("error marshalling config: %v", err) } - fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes)) + fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(content)) fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes)) } } diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go index f0737d46b..39a13b4af 100644 --- a/tools/parsers/go_parser_test.go +++ b/tools/parsers/go_parser_test.go @@ -34,6 +34,10 @@ func TestParseLine(t *testing.T) { Name: "BenchmarkIperf", Condition: []*bigquery.Condition{ { + Name: "iterations", + Value: "1", + }, + { Name: "GOMAXPROCS", Value: "6", }, @@ -63,6 +67,10 @@ func TestParseLine(t *testing.T) { Name: "BenchmarkRuby", Condition: []*bigquery.Condition{ { + Name: "iterations", + Value: "1", + }, + { 
Name: "GOMAXPROCS", Value: "6", }, @@ -100,12 +108,14 @@ func TestParseLine(t *testing.T) { } if !cmp.Equal(tc.want, got, nil) { - for _, c := range got.Condition { - t.Logf("Cond: %+v", c) + for i := range got.Condition { + t.Logf("Metric: want: %+v got:%+v", got.Condition[i], tc.want.Condition[i]) } - for _, m := range got.Metric { - t.Logf("Metric: %+v", m) + + for i := range got.Metric { + t.Logf("Metric: want: %+v got:%+v", got.Metric[i], tc.want.Metric[i]) } + t.Fatalf("Compare failed want: %+v got: %+v", tc.want, got) } }) @@ -131,7 +141,7 @@ func TestParseOutput(t *testing.T) { `, numBenchmarks: 2, numMetrics: 1, - numConditions: 1, + numConditions: 2, }, { name: "Ruby", @@ -142,7 +152,7 @@ BenchmarkRuby/server_threads.5 BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 465 requests_per_second.QPS`, numBenchmarks: 2, numMetrics: 3, - numConditions: 2, + numConditions: 3, }, } diff --git a/tools/vm/BUILD b/tools/vm/BUILD deleted file mode 100644 index d95ca6c63..000000000 --- a/tools/vm/BUILD +++ /dev/null @@ -1,63 +0,0 @@ -load("//tools:defs.bzl", "bzl_library", "cc_binary", "gtest") -load("//tools/vm:defs.bzl", "vm_image", "vm_test") - -package( - default_visibility = ["//:sandbox"], - licenses = ["notice"], -) - -sh_binary( - name = "zone", - srcs = ["zone.sh"], -) - -sh_binary( - name = "builder", - srcs = ["build.sh"], -) - -sh_binary( - name = "executer", - srcs = ["execute.sh"], -) - -cc_binary( - name = "test", - testonly = 1, - srcs = ["test.cc"], - linkstatic = 1, - deps = [ - gtest, - "//test/util:test_main", - ], -) - -vm_image( - name = "ubuntu1604", - family = "ubuntu-1604-lts", - project = "ubuntu-os-cloud", - scripts = [ - "//tools/vm/ubuntu1604", - ], -) - -vm_image( - name = "ubuntu1804", - family = "ubuntu-1804-lts", - project = "ubuntu-os-cloud", - scripts = [ - "//tools/vm/ubuntu1804", - ], -) - -vm_test( - name = "vm_test", - shard_count = 2, - targets = [":test"], -) - -bzl_library( - name = "defs_bzl", - 
srcs = ["defs.bzl"], - visibility = ["//visibility:private"], -) diff --git a/tools/vm/README.md b/tools/vm/README.md deleted file mode 100644 index 1e9859e66..000000000 --- a/tools/vm/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# VM Images & Tests - -All commands in this directory require the `gcloud` project to be set. - -For example: `gcloud config set project gvisor-kokoro-testing`. - -Images can be generated by using the `vm_image` rule. This rule will generate a -binary target that builds an image in an idempotent way, and can be referenced -from other rules. - -For example: - -``` -vm_image( - name = "ubuntu", - project = "ubuntu-1604-lts", - family = "ubuntu-os-cloud", - scripts = [ - "script.sh", - "other.sh", - ], -) -``` - -These images can be built manually by executing the target. The output on -`stdout` will be the image id (in the current project). - -For example: - -``` -$ bazel build :ubuntu -``` - -Images are always named per the hash of all the hermetic input scripts. This -allows images to be memoized quickly and easily. - -The `vm_test` rule can be used to execute a command remotely. This is still -under development however, and will likely change over time. - -For example: - -``` -vm_test( - name = "mycommand", - image = ":ubuntu", - targets = [":test"], -) -``` diff --git a/tools/vm/build.sh b/tools/vm/build.sh deleted file mode 100755 index 752b2b77b..000000000 --- a/tools/vm/build.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# This script is responsible for building a new GCP image that: 1) has nested -# virtualization enabled, and 2) has been completely set up with the -# image_setup.sh script. This script should be idempotent, as we memoize the -# setup script with a hash and check for that name. - -set -eou pipefail - -# Parameters. -declare -r USERNAME=${USERNAME:-test} -declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud} -declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts} -declare -r ZONE=${ZONE:-us-central1-f} - -# Random names. -declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z) -declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z) -declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z) - -# Hash inputs in order to memoize the produced image. -declare -r SETUP_HASH=$( (echo ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && cat "$@") | sha256sum - | cut -d' ' -f1 | cut -c 1-16) -declare -r IMAGE_NAME=${IMAGE_FAMILY:-image}-${SETUP_HASH} - -# Does the image already exist? Skip the build. -declare -r existing=$(set -x; gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)") -if ! [[ -z "${existing}" ]]; then - echo "${existing}" - exit 0 -fi - -# Standard arguments (applies only on script execution). -declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") - -# gcloud has path errors; is this a result of being a genrule? -export PATH=${PATH:-/bin:/usr/bin:/usr/local/bin} - -# Start a unique instance. Note that this instance will have a unique persistent -# disk as it's boot disk with the same name as the instance. 
-(set -x; gcloud compute instances create \ - --quiet \ - --image-project "${IMAGE_PROJECT}" \ - --image-family "${IMAGE_FAMILY}" \ - --boot-disk-size "200GB" \ - --zone "${ZONE}" \ - "${INSTANCE_NAME}" >/dev/null) -function cleanup { - (set -x; gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}") -} -trap cleanup EXIT - -# Wait for the instance to become available (up to 5 minutes). -echo -n "Waiting for ${INSTANCE_NAME}" >&2 -declare timeout=300 -declare success=0 -declare internal="" -declare -r start=$(date +%s) -declare -r end=$((${start}+${timeout})) -while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - echo -n "." >&2 - if gcloud compute ssh --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then - success=$((${success}+1)) - elif gcloud compute ssh --internal-ip --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then - success=$((${success}+1)) - internal="--internal-ip" - fi -done - -if [[ "${success}" -eq "0" ]]; then - echo "connect timed out after ${timeout} seconds." >&2 - exit 1 -else - echo "done." >&2 -fi - -# Run the install scripts provided. -for arg; do - (set -x; gcloud compute ssh ${internal} \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${SSH_ARGS[@]}" \ - sudo bash - <"${arg}" >/dev/null) -done - -# Stop the instance; required before creating an image. -(set -x; gcloud compute instances stop --quiet --zone "${ZONE}" "${INSTANCE_NAME}" >/dev/null) - -# Create a snapshot of the instance disk. -(set -x; gcloud compute disks snapshot \ - --quiet \ - --zone "${ZONE}" \ - --snapshot-names="${SNAPSHOT_NAME}" \ - "${INSTANCE_NAME}" >/dev/null) - -# Create the disk image. -(set -x; gcloud compute images create \ - --quiet \ - --source-snapshot="${SNAPSHOT_NAME}" \ - --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \ - "${IMAGE_NAME}" >/dev/null) - -# Finish up. 
-echo "${IMAGE_NAME}" diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl deleted file mode 100644 index 9af5ad3b4..000000000 --- a/tools/vm/defs.bzl +++ /dev/null @@ -1,202 +0,0 @@ -"""Image configuration. See README.md.""" - -load("//tools:defs.bzl", "default_installer") - -# vm_image_builder is a rule that will construct a shell script that actually -# generates a given VM image. Note that this does not _run_ the shell script -# (although it can be run manually). It will be run manually during generation -# of the vm_image target itself. This level of indirection is used so that the -# build system itself only runs the builder once when multiple targets depend -# on it, avoiding a set of races and conflicts. -def _vm_image_builder_impl(ctx): - # Generate a binary that actually builds the image. - builder = ctx.actions.declare_file(ctx.label.name) - script_paths = [] - for script in ctx.files.scripts: - script_paths.append(script.short_path) - builder_content = "\n".join([ - "#!/bin/bash", - "export ZONE=$(%s)" % ctx.files.zone[0].short_path, - "export USERNAME=%s" % ctx.attr.username, - "export IMAGE_PROJECT=%s" % ctx.attr.project, - "export IMAGE_FAMILY=%s" % ctx.attr.family, - "%s %s" % (ctx.files._builder[0].short_path, " ".join(script_paths)), - "", - ]) - ctx.actions.write(builder, builder_content, is_executable = True) - - # Note that the scripts should only be files, and should not include any - # indirect transitive dependencies. The build script wouldn't work. 
- return [DefaultInfo( - executable = builder, - runfiles = ctx.runfiles( - files = ctx.files.scripts + ctx.files._builder + ctx.files.zone, - ), - )] - -vm_image_builder = rule( - attrs = { - "_builder": attr.label( - executable = True, - default = "//tools/vm:builder", - cfg = "host", - ), - "username": attr.string(default = "$(whoami)"), - "zone": attr.label( - executable = True, - default = "//tools/vm:zone", - cfg = "host", - ), - "family": attr.string(mandatory = True), - "project": attr.string(mandatory = True), - "scripts": attr.label_list(allow_files = True), - }, - executable = True, - implementation = _vm_image_builder_impl, -) - -# See vm_image_builder above. -def _vm_image_impl(ctx): - # Run the builder to generate our output. - echo = ctx.actions.declare_file(ctx.label.name) - resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "\n".join([ - "set -e", - "image=$(%s)" % ctx.files.builder[0].path, - "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path, - "chmod 0755 %s" % echo.path, - ]), - tools = [ctx.attr.builder], - ) - ctx.actions.run_shell( - tools = resolved_inputs, - outputs = [echo], - progress_message = "Building image...", - execution_requirements = {"local": "true"}, - command = argv, - input_manifests = runfiles_manifests, - ) - - # Return just the echo command. All of the builder runfiles have been - # resolved and consumed in the generation of the trivial echo script. 
- return [DefaultInfo(executable = echo)] - -_vm_image_test = rule( - attrs = { - "builder": attr.label( - executable = True, - cfg = "host", - ), - }, - test = True, - implementation = _vm_image_impl, -) - -def vm_image(name, **kwargs): - vm_image_builder( - name = name + "_builder", - **kwargs - ) - _vm_image_test( - name = name, - builder = ":" + name + "_builder", - tags = [ - "local", - "manual", - ], - ) - -def _vm_test_impl(ctx): - runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) - - # Note that the remote execution case must actually generate an - # intermediate target in order to collect all the relevant runfiles so that - # they can be copied over for remote execution. - runner_content = "\n".join([ - "#!/bin/bash", - "export ZONE=$(%s)" % ctx.files.zone[0].short_path, - "export USERNAME=%s" % ctx.attr.username, - "export IMAGE=$(%s)" % ctx.files.image[0].short_path, - "export SUDO=%s" % "true" if ctx.attr.sudo else "false", - "%s %s" % ( - ctx.executable.executer.short_path, - " ".join([ - target.files_to_run.executable.short_path - for target in ctx.attr.targets - ]), - ), - "", - ]) - ctx.actions.write(runner, runner_content, is_executable = True) - - # Return with all transitive files. 
- runfiles = ctx.runfiles( - transitive_files = depset(transitive = [ - depset(target.data_runfiles.files) - for target in ctx.attr.targets - if hasattr(target, "data_runfiles") - ]), - files = ctx.files.executer + ctx.files.zone + ctx.files.image + - ctx.files.targets, - collect_default = True, - collect_data = True, - ) - return [DefaultInfo(executable = runner, runfiles = runfiles)] - -_vm_test = rule( - attrs = { - "image": attr.label( - executable = True, - default = "//tools/vm:ubuntu1804", - cfg = "host", - ), - "executer": attr.label( - executable = True, - default = "//tools/vm:executer", - cfg = "host", - ), - "username": attr.string(default = "$(whoami)"), - "zone": attr.label( - executable = True, - default = "//tools/vm:zone", - cfg = "host", - ), - "sudo": attr.bool(default = True), - "machine": attr.string(default = "n1-standard-1"), - "targets": attr.label_list( - mandatory = True, - allow_empty = False, - cfg = "target", - ), - }, - test = True, - implementation = _vm_test_impl, -) - -def vm_test( - installers = None, - **kwargs): - """Runs the given targets as a remote test. - - Args: - installer: Script to run before all targets. - **kwargs: All test arguments. Should include targets and image. - """ - targets = kwargs.pop("targets", []) - if installers == None: - installers = [ - "//tools/installers:head", - "//tools/installers:images", - ] - targets = installers + targets - if default_installer(): - targets = [default_installer()] + targets - _vm_test( - tags = [ - "local", - "manual", - ], - targets = targets, - local = 1, - **kwargs - ) diff --git a/tools/vm/execute.sh b/tools/vm/execute.sh deleted file mode 100755 index 1f1f3ce01..000000000 --- a/tools/vm/execute.sh +++ /dev/null @@ -1,160 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Required input. -if ! [[ -v IMAGE ]]; then - echo "no image provided: set IMAGE." - exit 1 -fi - -# Parameters. -declare -r USERNAME=${USERNAME:-test} -declare -r KEYNAME=$(mktemp --tmpdir -u key-XXXXXX) -declare -r SSHKEYS=$(mktemp --tmpdir -u sshkeys-XXXXXX) -declare -r INSTANCE_NAME=$(mktemp -u test-XXXXXX | tr A-Z a-z) -declare -r MACHINE=${MACHINE:-n1-standard-1} -declare -r ZONE=${ZONE:-us-central1-f} -declare -r SUDO=${SUDO:-false} - -# Standard arguments (applies only on script execution). -declare -ar SSH_ARGS=("-o" "ConnectTimeout=60" "--") - -# This script is executed as a test rule, which will reset the value of HOME. -# Unfortunately, it is needed to load the gconfig credentials. We will reset -# HOME when we actually execute in the remote environment, defined below. -export HOME=$(eval echo ~$(whoami)) - -# Generate unique keys for this test. -[[ -f "${KEYNAME}" ]] || ssh-keygen -t rsa -N "" -f "${KEYNAME}" -C "${USERNAME}" -cat > "${SSHKEYS}" <<EOF -${USERNAME}:$(cat ${KEYNAME}.pub) -EOF - -# Start a unique instance. This means that we first generate a unique set of ssh -# keys to ensure that only we have access to this instance. Note that we must -# constrain ourselves to Haswell or greater in order to have nested -# virtualization available. 
-gcloud compute instances create \ - --min-cpu-platform "Intel Haswell" \ - --preemptible \ - --no-scopes \ - --metadata block-project-ssh-keys=TRUE \ - --metadata-from-file ssh-keys="${SSHKEYS}" \ - --machine-type "${MACHINE}" \ - --image "${IMAGE}" \ - --zone "${ZONE}" \ - "${INSTANCE_NAME}" -function cleanup { - gcloud compute instances delete --quiet --zone "${ZONE}" "${INSTANCE_NAME}" -} -trap cleanup EXIT - -# Wait for the instance to become available (up to 5 minutes). -declare timeout=300 -declare success=0 -declare -r start=$(date +%s) -declare -r end=$((${start}+${timeout})) -while [[ "$(date +%s)" -lt "${end}" ]] && [[ "${success}" -lt 3 ]]; do - if gcloud compute ssh --ssh-key-file="${KEYNAME}" --zone "${ZONE}" "${USERNAME}"@"${INSTANCE_NAME}" -- true 2>/dev/null; then - success=$((${success}+1)) - fi -done -if [[ "${success}" -eq "0" ]]; then - echo "connect timed out after ${timeout} seconds." - exit 1 -fi - -# Copy the local directory over. -tar czf - --dereference --exclude=.git . | - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${SSH_ARGS[@]}" \ - tar xzf - - -# Execute the command remotely. -for cmd; do - # Setup relevant environment. - # - # N.B. This is not a complete test environment, but is complete enough to - # provide rudimentary sharding and test output support. - declare -a PREFIX=( "env" ) - if [[ -v TEST_SHARD_INDEX ]]; then - PREFIX+=( "TEST_SHARD_INDEX=${TEST_SHARD_INDEX}" ) - fi - if [[ -v TEST_SHARD_STATUS_FILE ]]; then - SHARD_STATUS_FILE=$(mktemp -u test-shard-status-XXXXXX) - PREFIX+=( "TEST_SHARD_STATUS_FILE=/tmp/${SHARD_STATUS_FILE}" ) - fi - if [[ -v TEST_TOTAL_SHARDS ]]; then - PREFIX+=( "TEST_TOTAL_SHARDS=${TEST_TOTAL_SHARDS}" ) - fi - if [[ -v TEST_TMPDIR ]]; then - REMOTE_TMPDIR=$(mktemp -u test-XXXXXX) - PREFIX+=( "TEST_TMPDIR=/tmp/${REMOTE_TMPDIR}" ) - # Create remotely. 
- gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${SSH_ARGS[@]}" \ - mkdir -p "/tmp/${REMOTE_TMPDIR}" - fi - if [[ -v XML_OUTPUT_FILE ]]; then - TEST_XML_OUTPUT=$(mktemp -u xml-output-XXXXXX) - PREFIX+=( "XML_OUTPUT_FILE=/tmp/${TEST_XML_OUTPUT}" ) - fi - if [[ "${SUDO}" == "true" ]]; then - PREFIX+=( "sudo" "-E" ) - fi - - # Execute the command. - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${SSH_ARGS[@]}" \ - "${PREFIX[@]}" "${cmd}" - - # Collect relevant results. - if [[ -v TEST_SHARD_STATUS_FILE ]]; then - gcloud compute scp \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${SHARD_STATUS_FILE}" \ - "${TEST_SHARD_STATUS_FILE}" 2>/dev/null || true # Allowed to fail. - fi - if [[ -v XML_OUTPUT_FILE ]]; then - gcloud compute scp \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}":/tmp/"${TEST_XML_OUTPUT}" \ - "${XML_OUTPUT_FILE}" 2>/dev/null || true # Allowed to fail. - fi - - # Clean up the temporary directory. - if [[ -v TEST_TMPDIR ]]; then - gcloud compute ssh \ - --ssh-key-file="${KEYNAME}" \ - --zone "${ZONE}" \ - "${USERNAME}"@"${INSTANCE_NAME}" -- \ - "${SSH_ARGS[@]}" \ - rm -rf "/tmp/${REMOTE_TMPDIR}" - fi -done diff --git a/tools/vm/ubuntu1604/10_core.sh b/tools/vm/ubuntu1604/10_core.sh deleted file mode 100755 index 629f7cf7a..000000000 --- a/tools/vm/ubuntu1604/10_core.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Install all essential build tools. -while true; do - if (apt-get update && apt-get install -y \ - make \ - git-core \ - build-essential \ - linux-headers-$(uname -r) \ - pkg-config); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install a recent go toolchain. -if ! [[ -d /usr/local/go ]]; then - wget https://dl.google.com/go/go1.13.5.linux-amd64.tar.gz - tar -xvf go1.13.5.linux-amd64.tar.gz - mv go /usr/local -fi - -# Link the Go binary from /usr/bin; replacing anything there. -(cd /usr/bin && rm -f go && ln -fs /usr/local/go/bin/go go) diff --git a/tools/vm/ubuntu1604/15_gcloud.sh b/tools/vm/ubuntu1604/15_gcloud.sh deleted file mode 100755 index bc2e5eccc..000000000 --- a/tools/vm/ubuntu1604/15_gcloud.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Install all essential build tools. 
-while true; do - if (apt-get update && apt-get install -y \ - apt-transport-https \ - ca-certificates \ - gnupg); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Add gcloud repositories. -echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | \ - tee -a /etc/apt/sources.list.d/google-cloud-sdk.list - -# Add the appropriate key. -curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \ - apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - - -# Install the gcloud SDK. -while true; do - if (apt-get update && apt-get install -y google-cloud-sdk); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done diff --git a/tools/vm/ubuntu1604/20_bazel.sh b/tools/vm/ubuntu1604/20_bazel.sh deleted file mode 100755 index bb7afa676..000000000 --- a/tools/vm/ubuntu1604/20_bazel.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -declare -r BAZEL_VERSION=2.0.0 - -# Install bazel dependencies. -while true; do - if (apt-get update && apt-get install -y \ - openjdk-8-jdk-headless \ - unzip); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Use the release installer. 
-curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh -rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh diff --git a/tools/vm/ubuntu1604/30_docker.sh b/tools/vm/ubuntu1604/30_docker.sh deleted file mode 100755 index d393133e4..000000000 --- a/tools/vm/ubuntu1604/30_docker.sh +++ /dev/null @@ -1,64 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Add dependencies. -while true; do - if (apt-get update && apt-get install -y \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install the key. -curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add - - -# Add the repository. -add-apt-repository \ - "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) \ - stable" - -# Install docker. -while true; do - if (apt-get update && apt-get install -y \ - docker-ce \ - docker-ce-cli \ - containerd.io); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Enable experimental features, for cross-building aarch64 images. -# Enable Docker IPv6. 
-cat > /etc/docker/daemon.json <<EOF -{ - "experimental": true, - "fixed-cidr-v6": "2001:db8:1::/64", - "ipv6": true -} -EOF diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh deleted file mode 100755 index d3b96c9ad..000000000 --- a/tools/vm/ubuntu1604/40_kokoro.sh +++ /dev/null @@ -1,72 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Declare kokoro's required public keys. -declare -r ssh_public_keys=( - "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com" - "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com" -) - -# Install dependencies. -while true; do - if (apt-get update && apt-get install -y \ - rsync \ - coreutils \ - python-psutil \ - qemu-kvm \ - python-pip \ - python3-pip \ - zip); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# junitparser is used to merge junit xml files. 
-pip install --no-cache-dir junitparser - -# We need a kbuilder user, which may already exist. -useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true - -# We need to provision appropriate keys. -mkdir -p ~kbuilder/.ssh -(IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys -chmod 0600 ~kbuilder/.ssh/authorized_keys -chown -R kbuilder ~kbuilder/.ssh - -# Give passwordless sudo access. -cat > /etc/sudoers.d/kokoro <<EOF -kbuilder ALL=(ALL) NOPASSWD:ALL -EOF - -# Ensure we can run Docker without sudo. -usermod -aG docker kbuilder - -# Ensure that we can access kvm. -usermod -aG kvm kbuilder - -# Ensure that /tmpfs exists and is writable by kokoro. -# -# Note that kokoro will typically attach a second disk (sdb) to the instance -# that is used for the /tmpfs volume. In the future we could setup an init -# script that formats and mounts this here; however, we don't expect our build -# artifacts to be that large. -mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY diff --git a/tools/vm/ubuntu1604/BUILD b/tools/vm/ubuntu1604/BUILD deleted file mode 100644 index ab1df0c4c..000000000 --- a/tools/vm/ubuntu1604/BUILD +++ /dev/null @@ -1,7 +0,0 @@ -package(licenses = ["notice"]) - -filegroup( - name = "ubuntu1604", - srcs = glob(["*.sh"]), - visibility = ["//:sandbox"], -) diff --git a/tools/vm/ubuntu1804/BUILD b/tools/vm/ubuntu1804/BUILD deleted file mode 100644 index 0c8856dde..000000000 --- a/tools/vm/ubuntu1804/BUILD +++ /dev/null @@ -1,7 +0,0 @@ -package(licenses = ["notice"]) - -alias( - name = "ubuntu1804", - actual = "//tools/vm/ubuntu1604", - visibility = ["//:sandbox"], -) diff --git a/tools/vm/zone.sh b/tools/vm/zone.sh deleted file mode 100755 index 79569fb19..000000000 --- a/tools/vm/zone.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -exec gcloud config get-value compute/zone diff --git a/tools/yamltest/BUILD b/tools/yamltest/BUILD new file mode 100644 index 000000000..475b3badd --- /dev/null +++ b/tools/yamltest/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "yamltest", + srcs = ["main.go"], + visibility = ["//visibility:public"], + deps = [ + "@com_github_xeipuuv_gojsonschema//:go_default_library", + "@in_gopkg_yaml_v2//:go_default_library", + ], +) diff --git a/tools/yamltest/defs.bzl b/tools/yamltest/defs.bzl new file mode 100644 index 000000000..fd04f947d --- /dev/null +++ b/tools/yamltest/defs.bzl @@ -0,0 +1,41 @@ +"""Tools for testing yaml files against schemas.""" + +def _yaml_test_impl(ctx): + """Implementation for yaml_test.""" + runner = ctx.actions.declare_file(ctx.label.name) + ctx.actions.write(runner, "\n".join([ + "#!/bin/bash", + "set -euo pipefail", + "%s -schema=%s -- %s" % ( + ctx.files._tool[0].short_path, + ctx.files.schema[0].short_path, + " ".join([f.short_path for f in ctx.files.srcs]), + ), + ]), is_executable = True) + return [DefaultInfo( + runfiles = ctx.runfiles(files = ctx.files._tool + ctx.files.schema + ctx.files.srcs), + executable = runner, + )] + +yaml_test = rule( + implementation = _yaml_test_impl, + doc = "Tests a yaml file against a schema.", + attrs = { + "srcs": attr.label_list( + doc = "The input yaml files.", + mandatory = True, + allow_files = True, + ), + "schema": attr.label( + doc = "The schema file in JSON schema format.", + allow_single_file = True, + mandatory = 
True, + ), + "_tool": attr.label( + executable = True, + cfg = "host", + default = Label("//tools/yamltest:yamltest"), + ), + }, + test = True, +) diff --git a/tools/yamltest/main.go b/tools/yamltest/main.go new file mode 100644 index 000000000..88271fb66 --- /dev/null +++ b/tools/yamltest/main.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary yamltest does strict yaml parsing and validation. +package main + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "os" + + "github.com/xeipuuv/gojsonschema" + yaml "gopkg.in/yaml.v2" +) + +func fixup(v interface{}) (interface{}, error) { + switch x := v.(type) { + case map[interface{}]interface{}: + // Coerse into a string-based map, required for yaml. + strMap := make(map[string]interface{}) + for k, v := range x { + strK, ok := k.(string) + if !ok { + // This cannot be converted to JSON at all. 
+ return nil, fmt.Errorf("invalid key %T in (%#v)", k, x) + } + fv, err := fixup(v) + if err != nil { + return nil, fmt.Errorf(".%s%w", strK, err) + } + strMap[strK] = fv + } + return strMap, nil + case []interface{}: + for i := range x { + fv, err := fixup(x[i]) + if err != nil { + return nil, fmt.Errorf("[%d]%w", i, err) + } + x[i] = fv + } + return x, nil + default: + return v, nil + } +} + +func loadFile(filename string) (gojsonschema.JSONLoader, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + dec := yaml.NewDecoder(f) + dec.SetStrict(true) + var object interface{} + if err := dec.Decode(&object); err != nil { + return nil, err + } + fixedObject, err := fixup(object) // For serialization. + if err != nil { + return nil, err + } + bytes, err := json.Marshal(fixedObject) + if err != nil { + return nil, err + } + return gojsonschema.NewStringLoader(string(bytes)), nil +} + +var schema = flag.String("schema", "", "path to JSON schema file.") + +func main() { + flag.Parse() + if *schema == "" || len(flag.Args()) == 0 { + flag.Usage() + os.Exit(2) + } + + // Construct our schema loader. + schemaLoader := gojsonschema.NewReferenceLoader(fmt.Sprintf("file://%s", *schema)) + + // Parse all documents. + allErrors := make(map[string][]error) + for _, filename := range flag.Args() { + // Record the filename with an empty slice for below, where + // we will emit all files (even those without any errors). + allErrors[filename] = nil + documentLoader, err := loadFile(filename) + if err != nil { + allErrors[filename] = append(allErrors[filename], err) + continue + } + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + allErrors[filename] = append(allErrors[filename], err) + continue + } + for _, desc := range result.Errors() { + allErrors[filename] = append(allErrors[filename], errors.New(desc.String())) + } + } + + // Print errors in yaml format. 
+ totalErrors := 0 + for filename, errs := range allErrors { + totalErrors += len(errs) + if len(errs) == 0 { + fmt.Fprintf(os.Stderr, "%s: ✓\n", filename) + continue + } + fmt.Fprintf(os.Stderr, "%s:\n", filename) + for _, err := range errs { + fmt.Fprintf(os.Stderr, "- %s\n", err) + } + } + if totalErrors != 0 { + os.Exit(1) + } +} diff --git a/website/blog/README.md b/website/blog/README.md new file mode 100644 index 000000000..e1d685288 --- /dev/null +++ b/website/blog/README.md @@ -0,0 +1,62 @@ +# gVisor blog + +The gVisor blog is owned and run by the gVisor team. + +## Contact + +Reach out to us on [gitter](https://gitter.im/gvisor/community) or the +[mailing list](https://groups.google.com/forum/#!forum/gvisor-users) if you +would like to write a blog post. + +## Submit a Post + +Anyone can write a blog post and submit it for review. Purely commercial content +or vendor pitches are not allowed. Please refer to the +[blog guidelines](#blog-guidelines) for more guidance about what content is +allowed. + +To submit a blog post, follow the steps below. + +1. [Sign the Contributor License Agreements](https://gvisor.dev/contributing/) + if you have not yet done so. +1. Familiarize yourself with the Markdown format for the + [existing blog posts](https://github.com/google/gvisor/tree/master/website/blog). +1. Write your blog post in a text editor of your choice. +1. (Optional) If you need help with Markdown, check out + [StackEdit](https://stackedit.io/app#) or read + [Jekyll's formatting reference](https://jekyllrb.com/docs/posts/#creating-posts) + for more information. +1. Click **Add file** > **Create new file**. +1. Paste your content into the editor and save it. Name the file in the + following way: *[BLOG] Your proposed title*, but don’t put the date in the + file name. The blog reviewers will work with you on the final file name, and + the date on which the blog will be published. +1.
When you save the file, GitHub will walk you through the pull request (PR) + process. +1. Send us a message on [gitter](https://gitter.im/gvisor/community) with a + link to your recently created PR. +1. A reviewer will be assigned to the pull request. They check your submission, + and work with you on feedback and final details. When the pull request is + approved, the blog will be scheduled for publication. + +### Blog Guidelines {#blog-guidelines} + +#### Suitable content: + +- **Original content only** +- gVisor features or project updates +- Tutorials and demos +- Use cases +- Content that is specific to a vendor or platform about gVisor installation + and use + +#### Unsuitable Content: + +- Blogs with no content relevant to gVisor +- Vendor pitches + +## Review Process + +Each blog post should be approved by at least one person on the team. Once all +of the review comments have been addressed and approved, a member of the team +will schedule publication of the blog post. diff --git a/website/blog/index.html b/website/blog/index.html index 5c67c95fc..272917fc4 100644 --- a/website/blog/index.html +++ b/website/blog/index.html @@ -20,3 +20,8 @@ pagination: {% if paginator.total_pages > 1 %} {% include paginator.html %} {% endif %} + +<hr> + +If you would like to contribute to the gVisor blog check out the +<a href="https://github.com/google/gvisor/tree/master/website/blog">instructions</a>. |