diff options
699 files changed, 43029 insertions, 10010 deletions
@@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# RBE requires a strong hash function, such as SHA256. +startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 + # Build with C++17. build --cxxopt=-std=c++17 @@ -20,13 +23,25 @@ build --stamp --workspace_status_command tools/workspace_status.sh # Enable remote execution so actions are performed on the remote systems. build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:remote --bes_backend=buildeventservice.googleapis.com +build:remote --bes_results_url="https://source.cloud.google.com/results/invocations" +build:remote --bes_timeout=600s build:remote --project_id=gvisor-rbe build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance +build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:remote3 --project_id=gvisor-rbe +build:remote3 --bes_backend=buildeventservice.googleapis.com +build:remote3 --bes_results_url="https://source.cloud.google.com/results/invocations" +build:remote3 --bes_timeout=600s +build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance + # Enable authentication. This will pick up application default credentials by # default. You can use --google_credentials=some_file.json to use a service # account credential instead. build:remote --google_default_credentials=true build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" +build:remote3 --google_default_credentials=true +build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. @@ -35,10 +50,15 @@ build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-defa build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain -build:remote --jobs=50 +build:remote --jobs=100 build:remote --remote_timeout=3600 -# RBE requires a strong hash function, such as SHA256. -startup --host_jvm_args=-Dbazel.DigestFunction=SHA256 +build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3 +build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 +build:remote3 --crosstool_top=@rbe_default//cc:toolchain +build:remote3 --jobs=100 +build:remote3 --remote_timeout=3600 # Set flags for uploading to BES in order to view results in the Bazel Build # Results UI. diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 10c86f5cd..18b805c15 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,6 +11,13 @@ jobs: generate: runs-on: ubuntu-latest steps: + - id: setup + run: | + if ! [[ -z "${{ secrets.GO_TOKEN }}" ]]; then + echo ::set-output has_token=true + else + echo ::set-output has_token=false + fi - run: | jq -nc '{"state": "pending", "context": "go tests"}' | \ curl -sL -X POST -d @- \ @@ -19,12 +26,12 @@ jobs: "${{ github.event.pull_request.statuses_url }}" if: github.event_name == 'pull_request' - uses: actions/checkout@v2 - if: github.event_name == 'push' + if: github.event_name == 'push' && steps.setup.outputs.has_token == 'true' with: fetch-depth: 0 token: '${{ secrets.GO_TOKEN }}' - uses: actions/checkout@v2 - if: github.event_name == 'pull_request' + if: github.event_name == 'pull_request' || steps.setup.outputs.has_token != 'true' with: fetch-depth: 0 - uses: actions/setup-go@v2 diff --git a/.github/workflows/issue_reviver.yml b/.github/workflows/issue_reviver.yml index 5e0254111..e68e15270 100644 --- a/.github/workflows/issue_reviver.yml +++ b/.github/workflows/issue_reviver.yml @@ -4,11 +4,13 @@ on: - cron: '0 0 * * *' jobs: - label: + issue_reviver: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + if: github.repository == "google/gvisor" - run: make run TARGETS="//tools/issue_reviver" + if: github.repository == "google/gvisor" env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_REPOSITORY: ${{ github.repository }} diff --git a/.travis.yml b/.travis.yml index 4cb202b7d..1d955b05d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,27 @@ language: shell dist: xenial +git: + clone: false # Clone manually in before_install +before_install: + - set -e -o pipefail + - | + if [ "${TRAVIS_PULL_REQUEST}" = false ]; then + # This is not a PR build, fetch and checkout the commit being tested + git clone -q --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}" + cd "${TRAVIS_REPO_SLUG}" + git fetch origin "${TRAVIS_COMMIT}" --depth 1 + git checkout -qf "${TRAVIS_COMMIT}" + else + # This is a PR build, simulate +refs/pull/{num}/merge. + # We can do that by fetching +refs/pull/{num}/head and cherry picking it + # onto the target branch. + git clone -q --branch "${TRAVIS_BRANCH}" --depth 1 "https://github.com/${TRAVIS_REPO_SLUG}.git" "${TRAVIS_REPO_SLUG}" + cd "${TRAVIS_REPO_SLUG}" + git fetch origin "+refs/pull/${TRAVIS_PULL_REQUEST}/head" --depth 1 + git config --global user.email "$(git log -1 FETCH_HEAD --pretty="%cE")" + git config --global user.name "$(git log -1 FETCH_HEAD --pretty="%aN")" + git cherry-pick --strategy=recursive -X theirs --keep-redundant-commits FETCH_HEAD + fi cache: directories: - /home/travis/.cache/bazel/ @@ -8,8 +30,10 @@ services: - docker jobs: include: - - os: linux - arch: amd64 + # AMD64 builds are tested on kokoro, so don't run them in travis to save + # capacity for arm64 builds. + # - os: linux + # arch: amd64 - os: linux arch: arm64 script: @@ -69,7 +69,10 @@ go_path( name = "gopath", mode = "link", deps = [ + # Main binary. "//runsc", + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", # Packages that are not dependencies of //runsc. "//pkg/sentry/kernel/memevent", diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3f8f4c985..89180eb3f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -37,10 +37,8 @@ Dependencies can be added by using `go mod get`. In order to keep the ### Coding Guidelines -All Go code should conform to the [Go style guidelines][gostyle]. C++ code -should conform to the [Google C++ Style Guide][cppstyle] and the guidelines -described for tests. Note that code may be automatically formatted per the -guidelines when merged. +All code should comply with the [style guide](g3doc/style.md). Note that code +may be automatically formatted per the guidelines when merged. As a secure runtime, we need to maintain the safety of all of code included in gVisor. The following rules help mitigate issues. @@ -125,9 +123,7 @@ Contributions made by corporations are covered by a different agreement than the one above, the [Software Grant and Corporate Contributor License Agreement][gccla]. -[cppstyle]: https://google.github.io/styleguide/cppguide.html [gcla]: https://cla.developers.google.com/about/google-individual [gccla]: https://cla.developers.google.com/about/google-corporate [github]: https://github.com/google/gvisor/compare [gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev -[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments @@ -200,3 +200,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +------------------ + +Some files carry the following license, noted at the top of each file: + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + THE SOFTWARE.
\ No newline at end of file @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Helpful pretty-printer. +MAKEBANNER := \033[1;34mmake\033[0m +submake = echo -e '$(MAKEBANNER) $1' >&2; $(MAKE) $1 + # Described below. OPTIONS := STARTUP_OPTIONS := @@ -85,7 +89,7 @@ endif ## define images $(1)-%: ## Image tool: $(1) a given image (also may use 'all-images'). - @$(MAKE) -C images $$@ + @$(call submake,-C images $$@) endef rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'. $(eval $(call images,rebuild)) @@ -96,7 +100,7 @@ $(eval $(call images,push)) load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'. $(eval $(call images,load)) list-images: ## List all available images. - @$(MAKE) -C images $$@ + @$(call submake, -C images $$@) ## ## Canonical build and test targets. @@ -106,21 +110,132 @@ list-images: ## List all available images. ## new subsystem or workflow, consider adding a new target here. ## runsc: ## Builds the runsc binary. - @$(MAKE) build TARGETS="//runsc" + @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc") .PHONY: runsc +debian: ## Builds the debian packages. + @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc:runsc-debian") +.PHONY: debian + smoke-test: ## Runs a simple smoke test after build runsc. - @$(MAKE) run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true" + @$(call submake,run DOCKER_PRIVILEGED="" ARGS="--alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do true") .PHONY: smoke-tests -unit-tests: ## Runs all unit tests in pkg runsc and tools. - @$(MAKE) test OPTIONS="pkg/... runsc/... tools/..." -.PHONY: unit-tests +unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc. + @$(call submake,test TARGETS="pkg/... runsc/... tools/... benchmarks/... benchmarks/runner:runner_test") -tests: ## Runs all local ptrace system call tests. - @$(MAKE) test OPTIONS="--test_tag_filters runsc_ptrace test/syscalls/..." +tests: ## Runs all unit tests and syscall tests. +tests: unit-tests + @$(call submake,test TARGETS="test/syscalls/...") .PHONY: tests + +integration-tests: ## Run all standard integration tests. +integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests +integration-tests: do-tests kvm-tests root-tests containerd-tests +.PHONY: integration-tests + +network-tests: ## Run all networking integration tests. +network-tests: iptables-tests packetdrill-tests packetimpact-tests +.PHONY: network-tests + +# Standard integration targets. +INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test + +syscall-%-tests: + @$(call submake,test OPTIONS="--test_tag_filters runsc_$* test/syscalls/...") + +syscall-native-tests: + @$(call submake,test OPTIONS="--test_tag_filters native test/syscalls/...") +.PHONY: syscall-native-tests + +syscall-tests: ## Run all system call tests. +syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests +.PHONY: syscall-tests + +%-runtime-tests: load-runtimes_% + @$(call submake,install-test-runtime) + @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + +do-tests: runsc + @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true") + @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true") + @$(call submake,sudo TARGETS="//runsc" ARGS="do true") +.PHONY: do-tests + +simple-tests: unit-tests # Compatibility target. +.PHONY: simple-tests + +IMAGE_FILTER := HelloWorld\|Httpd\|Ruby\|Stdio +INTEGRATION_FILTER := Life\|Pause\|Connect\|JobControl\|Overlay\|Exec\|DirCreation/root + +docker-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="vfs1") + @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)") + @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") + @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=$(IMAGE_FILTER)\|$(INTEGRATION_FILTER)" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: docker-tests + +overlay-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="overlay" ARGS="--overlay") + @$(call submake,test-runtime RUNTIME="overlay" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: overlay-tests + +swgso-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false") + @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: swgso-tests + +hostnet-tests: load-basic-images + @$(call submake,install-test-runtime RUNTIME="hostnet" ARGS="--network=host") + @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: hostnet-tests + +kvm-tests: load-basic-images + @(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm + @if ! [[ -w /dev/kvm ]]; then sudo chmod a+rw /dev/kvm; fi + @$(call submake,test TARGETS="//pkg/sentry/platform/kvm:kvm_test") + @$(call submake,install-test-runtime RUNTIME="kvm" ARGS="--platform=kvm") + @$(call submake,test-runtime RUNTIME="kvm" TARGETS="$(INTEGRATION_TARGETS)") +.PHONY: kvm-tests + +iptables-tests: load-iptables + @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test") + @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw") + @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") +.PHONY: iptables-tests + +packetdrill-tests: load-packetdrill + @$(call submake,install-test-runtime RUNTIME="packetdrill") + @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')") +.PHONY: packetdrill-tests + +packetimpact-tests: load-packetimpact + @$(call submake,install-test-runtime RUNTIME="packetimpact") + @$(call submake,test-runtime RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')") +.PHONY: packetimpact-tests + +root-tests: load-basic-images + @$(call submake,install-test-runtime) + @$(call submake,sudo TARGETS="//test/root:root_test" ARGS="-test.v") +.PHONY: test-root + +# Specific containerd version tests. +containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd install-test-runtime + @CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd" + @$(MAKE) sudo TARGETS="tools/installers:shim" + @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="-test.v" + +# Note that we can't run containerd-test-1.1.8 tests here. +# +# Containerd 1.1.8 should work, but because of a bug in loading images locally +# (https://github.com/kubernetes-sigs/cri-tools/issues/421), we are unable to +# actually drive the tests. The v1 API is tested exclusively through 1.2.13. +containerd-tests: ## Runs all supported containerd version tests. +containerd-tests: containerd-test-1.2.13 +containerd-tests: containerd-test-1.3.4 +containerd-tests: containerd-test-1.4.0-beta.0 + ## ## Website & documentation helpers. ## @@ -138,7 +253,7 @@ WEBSITE_PROJECT := gvisordev WEBSITE_REGION := us-central1 website-build: load-jekyll ## Build the site image locally. - @$(MAKE) run TARGETS="//website:website" + @$(call submake,run TARGETS="//website:website") .PHONY: website-build website-server: website-build ## Run a local server for development. @@ -189,8 +304,8 @@ $(RELEASE_KEY): release: $(RELEASE_KEY) ## Builds a release. @mkdir -p $(RELEASE_ROOT) @T=$$(mktemp -d /tmp/release.XXXXXX); \ - $(MAKE) copy TARGETS="runsc" DESTINATION=$$T && \ - $(MAKE) copy TARGETS="runsc:runsc-debian" DESTINATION=$$T && \ + $(call submake,copy TARGETS="runsc" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="runsc:runsc-debian" DESTINATION=$$T) && \ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \ rc=$$?; rm -rf $$T; exit $$rc .PHONY: release @@ -213,42 +328,47 @@ tag: ## Creates and pushes a release tag. ## ifeq (,$(BRANCH_NAME)) RUNTIME := runsc -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/runsc +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) else RUNTIME := $(BRANCH_NAME) -RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(BRANCH_NAME) +RUNTIME_DIR := $(shell dirname $(shell mktemp -u))/$(RUNTIME) endif RUNTIME_BIN := $(RUNTIME_DIR)/runsc RUNTIME_LOG_DIR := $(RUNTIME_DIR)/logs RUNTIME_LOGS := $(RUNTIME_LOG_DIR)/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND% dev: ## Installs a set of local runtimes. Requires sudo. - @$(MAKE) refresh ARGS="--net-raw" - @$(MAKE) configure RUNTIME="$(RUNTIME)" ARGS="--net-raw" - @$(MAKE) configure RUNTIME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets" - @$(MAKE) configure RUNTIME="$(RUNTIME)-p" ARGS="--net-raw --profile" - @$(MAKE) configure RUNTIME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2" + @$(call submake,refresh ARGS="--net-raw") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)" ARGS="--net-raw") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-d" ARGS="--net-raw --debug --strace --log-packets") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-p" ARGS="--net-raw --profile") + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)-vfs2-d" ARGS="--net-raw --debug --strace --log-packets --vfs2") @sudo systemctl restart docker .PHONY: dev -refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'test-install' first. +refresh: ## Refreshes the runtime binary (for development only). Must have called 'dev' or 'install-test-runtime' first. @mkdir -p "$(RUNTIME_DIR)" - @$(MAKE) copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)" && chmod 0755 "$(RUNTIME_BIN)" + @$(call submake,copy TARGETS=runsc DESTINATION="$(RUNTIME_BIN)") .PHONY: install -test-install: ## Installs the runtime for testing. Requires sudo. - @$(MAKE) refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)" - @$(MAKE) configure +install-test-runtime: ## Installs the runtime for testing. Requires sudo. + @$(call submake,refresh ARGS="--net-raw --TESTONLY-test-name-env=RUNSC_TEST_NAME --debug --strace --log-packets $(ARGS)") + @$(call submake,configure RUNTIME_NAME=runsc) + @$(call submake,configure RUNTIME_NAME="$(RUNTIME)") @sudo systemctl restart docker -.PHONY: install-test + @if [[ -f /etc/docker/daemon.json ]]; then \ + sudo chmod 0755 /etc/docker && \ + sudo chmod 0644 /etc/docker/daemon.json; \ + fi +.PHONY: install-test-runtime -configure: ## Configures a single runtime. Requires sudo. Typically called from dev or test-install. - @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) - @echo "Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" - @echo "Logs are in: $(RUNTIME_LOG_DIR)" +configure: ## Configures a single runtime. Requires sudo. Typically called from dev or install-test-runtime. + @sudo sudo "$(RUNTIME_BIN)" install --experimental=true --runtime="$(RUNTIME_NAME)" -- --debug-log "$(RUNTIME_LOGS)" $(ARGS) + @echo -e "$(INFO) Installed runtime \"$(RUNTIME)\" @ $(RUNTIME_BIN)" + @echo -e "$(INFO) Logs are in: $(RUNTIME_LOG_DIR)" @sudo rm -rf "$(RUNTIME_LOG_DIR)" && mkdir -p "$(RUNTIME_LOG_DIR)" .PHONY: configure test-runtime: ## A convenient wrapper around test that provides the runtime argument. Target must still be provided. - @$(MAKE) test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)" -.PHONY: runtime-test + @$(call submake,test OPTIONS="$(OPTIONS) --test_arg=--runtime=$(RUNTIME)") +.PHONY: test-runtime @@ -113,7 +113,6 @@ See [SECURITY.md](SECURITY.md). See [Contributing.md](CONTRIBUTING.md). [bazel]: https://bazel.build -[community]: https://gvisor.googlesource.com/community [docker]: https://www.docker.com [gvisor-users-list]: https://groups.google.com/forum/#!forum/gvisor-users [gvisor-dev]: https://gvisor.dev diff --git a/SECURITY.md b/SECURITY.md index 82cd0efb8..a96843895 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -7,5 +7,4 @@ prompt response, typically within 48 hours. Policies for security list access, vulnerability embargo, and vulnerability disclosure are outlined in the [governance policy](GOVERNANCE.md). -[community]: https://gvisor.googlesource.com/community [gvisor-security-list]: https://groups.google.com/forum/#!forum/gvisor-security @@ -4,11 +4,11 @@ load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") # Bazel/starlark utilities. http_archive( name = "bazel_skylib", + sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", ], - sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44", ) load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") @@ -42,6 +42,28 @@ http_archive( ], ) +http_archive( + name = "io_bazel_rules_go_bazel3", # To replace the above. + patch_args = ["-p1"], + patches = [ + "//tools/nogo:io_bazel_rules_go-visibility.patch", + ], + sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", + ], +) + +http_archive( + name = "bazel_gazelle_bazel3", # To replace the above. + sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", + ], +) + load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() @@ -52,9 +74,6 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") gazelle_dependencies() -# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories" -# block below once 1876 is fixed. -# # The com_google_protobuf repository below would trigger downloading a older # version of org_golang_x_sys. If putting this repository statment in a place # after that of the com_google_protobuf, this statement will not work as @@ -126,6 +145,16 @@ http_archive( ], ) +http_archive( + name = "bazel_toolchains_bazel3", # To replace the above. + sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48", + strip_prefix = "bazel-toolchains-3.1.1", + urls = [ + "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz", + ], +) + # Creates a default toolchain config for RBE. load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig") @@ -159,12 +188,64 @@ load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps") grpc_extra_deps() -# External repositories, in sorted order. +# System Call test dependencies. +http_archive( + name = "com_google_absl", + sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", + strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", + urls = [ + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", + ], +) + +http_archive( + name = "com_google_googletest", + sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", + strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", + urls = [ + "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", + "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", + ], +) + +http_archive( + name = "com_google_benchmark", + sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", + strip_prefix = "benchmark-1.5.0", + urls = [ + "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", + "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", + ], +) + +# External Go repositories. +# +# Unfortunately, gazelle will automatically parse go modules in the +# repositories and generate new go_repository stanzas. These may not respect +# pins that we have in go.mod or below. So order actually matters here. + +go_repository( + name = "com_github_sirupsen_logrus", + importpath = "github.com/sirupsen/logrus", + replace = "github.com/Sirupsen/logrus", + sum = "h1:cWjBmzJnL1sO88XdqJYmq7aiWClqXIQQMJ3Utgy1f+I=", + version = "v1.4.2", +) + +go_repository( + name = "com_github_containerd_containerd", + build_file_proto_mode = "disable", + importpath = "github.com/containerd/containerd", + sum = "h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI=", + version = "v1.3.4", +) + go_repository( name = "com_github_cenkalti_backoff", importpath = "github.com/cenkalti/backoff", - sum = "h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8=", - version = "v0.0.0-20190506075156-2146c9339422", + sum = "h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4=", + version = "v1.1.1-0.20190506075156-2146c9339422", ) go_repository( @@ -182,38 +263,31 @@ go_repository( ) go_repository( - name = "com_github_google_go-cmp", - importpath = "github.com/google/go-cmp", - sum = "h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ=", - version = "v0.2.0", -) - -go_repository( name = "com_github_google_subcommands", importpath = "github.com/google/subcommands", - sum = "h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI=", - version = "v0.0.0-20190508160503-636abe8753b8", + sum = "h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM=", + version = "v1.0.2-0.20190508160503-636abe8753b8", ) go_repository( name = "com_github_google_uuid", importpath = "github.com/google/uuid", - sum = "h1:rXQlD9GXkjA/PQZhmEaF/8Pj/sJfdZJK7GJG0gkS8I0=", - version = "v0.0.0-20171129191014-dec09d789f3d", + sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=", + version = "v1.0.0", ) go_repository( name = "com_github_kr_pretty", importpath = "github.com/kr/pretty", - sum = "h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=", - version = "v0.2.0", + sum = "h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=", + version = "v0.1.0", ) go_repository( name = "com_github_kr_pty", importpath = "github.com/kr/pty", - sum = "h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw=", - version = "v1.1.1", + sum = "h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M=", + version = "v1.1.4-0.20190131011033-7dc38fb350b1", ) go_repository( @@ -225,15 +299,9 @@ go_repository( go_repository( name = "com_github_mohae_deepcopy", - commit = "c48cc78d482608239f6c4c92a4abd87eb8761c90", importpath = "github.com/mohae/deepcopy", -) - -go_repository( - name = "com_github_opencontainers_runtime-spec", - importpath = "github.com/opencontainers/runtime-spec", - sum = "h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI=", - version = "v0.1.2-0.20171211145439-b2d941ef6a78", + sum = "h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c=", + version = "v0.0.0-20170308212314-bb9b5e7adda9", ) go_repository( @@ -246,65 +314,58 @@ go_repository( go_repository( name = "com_github_vishvananda_netlink", importpath = "github.com/vishvananda/netlink", - sum = "h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q=", - version = "v1.0.1-0.20190318003149-adb577d4a45e", -) - -go_repository( - name = "com_github_vishvananda_netns", - importpath = "github.com/vishvananda/netns", - sum = "h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk=", - version = "v0.0.0-20171111001504-be1fbeda1936", + sum = "h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w=", + version = "v1.0.1-0.20190930145447-2ec5bdc52b86", ) go_repository( name = "org_golang_google_grpc", build_file_proto_mode = "disable", importpath = "google.golang.org/grpc", - sum = "h1:zvIju4sqAGvwKspUQOhwnpcqSbzi7/H6QomNNjTL4sk=", - version = "v1.27.1", + sum = "h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo=", + version = "v1.29.0", ) go_repository( name = "in_gopkg_check_v1", importpath = "gopkg.in/check.v1", - sum = "h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=", - version = "v1.0.0-20190902080502-41f04d3bba15", + sum = "h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=", + version = "v1.0.0-20180628173108-788fd7840127", ) go_repository( name = "org_golang_x_crypto", importpath = "golang.org/x/crypto", - sum = "h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=", - version = "v0.0.0-20190308221718-c2843e01d9a2", + sum = "h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI=", + version = "v0.0.0-20200622213623-75b288015ac9", ) go_repository( name = "org_golang_x_mod", importpath = "golang.org/x/mod", - sum = "h1:p1YOIz9H/mGN8k1XkaV5VFAq9+zhN9Obefv439UwRhI=", - version = "v0.2.1-0.20200224194123-e5e73c1b9c72", + sum = "h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4=", + version = "v0.3.0", ) go_repository( name = "org_golang_x_net", importpath = "golang.org/x/net", - sum = "h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI=", - version = "v0.0.0-20190620200207-3b0461eec859", + sum = "h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4=", + version = "v0.0.0-20200625001655-4c5254603344", ) go_repository( name = "org_golang_x_sync", importpath = "golang.org/x/sync", - sum = "h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=", - version = "v0.0.0-20190423024810-112230192c58", + sum = "h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA=", + version = "v0.0.0-20200625203802-6e8e738ad208", ) go_repository( name = "org_golang_x_text", importpath = "golang.org/x/text", - sum = "h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg=", - version = "v0.3.0", + sum = "h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=", + version = "v0.3.2", ) go_repository( @@ -317,15 +378,15 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:Uglradbb4KfUWaYasZhlsDsGRwHHvRsHoNAEONef0W8=", - version = "v0.0.0-20200131233409-575de47986ce", + sum = "h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE=", + version = "v0.0.0-20200707200213-416e8f4faf8a", ) go_repository( name = "org_golang_x_xerrors", importpath = "golang.org/x/xerrors", - sum = "h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=", - version = "v0.0.0-20190717185122-a985d3407aa7", + sum = "h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=", + version = "v0.0.0-20191204190536-9bdfabe68543", ) go_repository( @@ -338,43 +399,106 @@ go_repository( go_repository( name = "com_github_golang_protobuf", importpath = "github.com/golang/protobuf", - sum = "h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=", - version = "v1.3.1", + sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=", + version = "v1.4.2", ) go_repository( - name = "com_github_google_go-github", - importpath = "github.com/google/go-github", - sum = "h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY=", - version = "v17.0.0", + name = "org_golang_x_oauth2", + importpath = "golang.org/x/oauth2", + sum = "h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=", + version = "v0.0.0-20200107190931-bf48bf16ab8d", ) go_repository( - name = "org_golang_x_oauth2", - importpath = "golang.org/x/oauth2", - sum = "h1:pE8b58s1HRDMi8RDc79m0HISf9D4TzseP40cEA6IGfs=", - version = "v0.0.0-20191202225959-858c2ad4c8b6", + name = "com_github_docker_docker", + importpath = "github.com/docker/docker", + sum = "h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w=", + version = "v1.4.2-0.20191028175130-9e7d5ac5ea55", ) go_repository( - name = "com_github_google_go-querystring", - importpath = "github.com/google/go-querystring", - sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=", + name = "com_github_docker_go_connections", + importpath = "github.com/docker/go-connections", + sum = "h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o=", + version = "v0.3.0", +) + +go_repository( + name = "com_github_pkg_errors", + importpath = "github.com/pkg/errors", + sum = "h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=", + version = "v0.9.1", +) + +go_repository( + name = "com_github_docker_go_units", + importpath = "github.com/docker/go-units", + sum = "h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=", + version = "v0.4.0", +) + +go_repository( + name = "com_github_opencontainers_go_digest", + importpath = "github.com/opencontainers/go-digest", + sum = "h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=", version = "v1.0.0", ) go_repository( - name = "com_google_cloud_go_bigquery", - importpath = "cloud.google.com/go/bigquery", - sum = "h1:K2NyuHRuv15ku6eUpe0DQk5ZykPMnSOnvuVf6IHcjaE=", - version = "v1.5.0", + name = "com_github_docker_distribution", + importpath = "github.com/docker/distribution", + sum = "h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE=", + version = "v2.7.1-0.20190205005809-0d3efadf0154+incompatible", ) go_repository( - name = "org_golang_google_api", - importpath = "google.golang.org/api", - sum = "h1:jz2KixHX7EcCPiQrySzPdnYT7DbINAypCqKZ1Z7GM40=", - version = "v0.20.0", + name = "com_github_davecgh_go_spew", + importpath = "github.com/davecgh/go-spew", + sum = "h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=", + version = "v1.1.1", +) + +go_repository( + name = "com_github_konsorten_go_windows_terminal_sequences", + importpath = "github.com/konsorten/go-windows-terminal-sequences", + sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=", + version = "v1.0.2", +) + +go_repository( + name = "com_github_pmezard_go_difflib", + importpath = "github.com/pmezard/go-difflib", + sum = "h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_stretchr_testify", + importpath = "github.com/stretchr/testify", + sum = "h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=", + version = "v1.4.0", +) + +go_repository( + name = "com_github_opencontainers_image_spec", + importpath = "github.com/opencontainers/image-spec", + sum = "h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI=", + version = "v1.0.1", +) + +go_repository( + name = "com_github_microsoft_go_winio", + importpath = "github.com/Microsoft/go-winio", + sum = "h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA=", + version = "v0.4.15-0.20190919025122-fc70bd9a86b5", +) + +go_repository( + name = "com_github_stretchr_objx", + importpath = "github.com/stretchr/objx", + sum = "h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A=", + version = "v0.1.1", ) go_repository( @@ -387,16 +511,457 @@ go_repository( go_repository( name = "org_uber_go_multierr", importpath = "go.uber.org/multierr", - sum = "h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A=", - version = "v1.5.0", + sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=", + version = "v1.2.0", ) -# BigQuery Dependencies for Benchmarks go_repository( name = "com_google_cloud_go", importpath = "cloud.google.com/go", - sum = "h1:eoz/lYxKSL4CNAiaUJ0ZfD1J3bfMYbU5B3rwM1C1EIU=", - version = "v0.55.0", + sum = "h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo=", + version = "v0.52.1-0.20200122224058-0482b626c726", +) + +go_repository( + name = "io_opencensus_go", + importpath = "go.opencensus.io", + sum = "h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs=", + version = "v0.22.2", +) + +go_repository( + name = "co_honnef_go_tools", + importpath = "honnef.co/go/tools", + sum = "h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=", + version = "v0.0.1-2019.2.3", +) + +go_repository( + name = "com_github_burntsushi_toml", + importpath = "github.com/BurntSushi/toml", + sum = "h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=", + version = "v0.3.1", +) + +go_repository( + name = "com_github_census_instrumentation_opencensus_proto", + importpath = "github.com/census-instrumentation/opencensus-proto", + sum = "h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=", + version = "v0.2.1", +) + +go_repository( + name = "com_github_client9_misspell", + importpath = "github.com/client9/misspell", + sum = "h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI=", + version = "v0.3.4", +) + +go_repository( + name = "com_github_cncf_udpa_go", + importpath = "github.com/cncf/udpa/go", + sum = "h1:WBZRG4aNOuI15bLRrCgN8fCq8E5Xuty6jGbmSNEvSsU=", + version = "v0.0.0-20191209042840-269d4d468f6f", +) + +go_repository( + name = "com_github_containerd_cgroups", + build_file_proto_mode = "disable", + importpath = "github.com/containerd/cgroups", + sum = "h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY=", + version = "v0.0.0-20181219155423-39b18af02c41", +) + +go_repository( + name = "com_github_containerd_console", + importpath = "github.com/containerd/console", + sum = "h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc=", + version = "v0.0.0-20191206165004-02ecf6a7291e", +) + +go_repository( + name = "com_github_containerd_continuity", + importpath = "github.com/containerd/continuity", + sum = "h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=", + version = "v0.0.0-20200710164510-efbc4488d8fe", +) + +go_repository( + name = "com_github_containerd_fifo", + importpath = "github.com/containerd/fifo", + sum = "h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw=", + version = "v0.0.0-20191213151349-ff969a566b00", +) + +go_repository( + name = "com_github_containerd_go_runc", + importpath = "github.com/containerd/go-runc", + sum = "h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds=", + version = "v0.0.0-20200220073739-7016d3ce2328", +) + +go_repository( + name = "com_github_containerd_ttrpc", + importpath = "github.com/containerd/ttrpc", + sum = "h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc=", + version = "v0.0.0-20200121165050-0be804eadb15", +) + +go_repository( + name = "com_github_coreos_go_systemd", + importpath = "github.com/coreos/go-systemd", + sum = "h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU=", + version = "v0.0.0-20191104093116-d3cd4ed1dbcf", +) + +go_repository( + name = "com_github_docker_go_events", + importpath = "github.com/docker/go-events", + sum = "h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8=", + version = "v0.0.0-20190806004212-e31b211e4f1c", +) + +go_repository( + name = "com_github_dustin_go_humanize", + importpath = "github.com/dustin/go-humanize", + sum = "h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=", + version = "v0.0.0-20171111073723-bb3d318650d4", +) + +go_repository( + name = "com_github_envoyproxy_go_control_plane", + importpath = "github.com/envoyproxy/go-control-plane", + sum = "h1:rEvIZUSZ3fx39WIi3JkQqQBitGwpELBIYWeBVh6wn+E=", + version = "v0.9.4", +) + +go_repository( + name = "com_github_envoyproxy_protoc_gen_validate", + importpath = "github.com/envoyproxy/protoc-gen-validate", + sum = "h1:EQciDnbrYxy13PgWoY8AqoxGiPrpgBZ1R8UNe3ddc+A=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_fsnotify_fsnotify", + importpath = "github.com/fsnotify/fsnotify", + sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=", + version = "v1.4.7", +) + +go_repository( + name = "com_github_godbus_dbus", + importpath = "github.com/godbus/dbus", + sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=", + version = "v0.0.0-20190422162347-ade71ed3457e", +) + +go_repository( + name = "com_github_gogo_googleapis", + importpath = "github.com/gogo/googleapis", + sum = "h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=", + version = "v1.4.0", +) + +go_repository( + name = "com_github_gogo_protobuf", + importpath = "github.com/gogo/protobuf", + sum = "h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls=", + version = "v1.3.1", +) + +go_repository( + name = "com_github_golang_glog", + importpath = "github.com/golang/glog", + sum = "h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58=", + version = "v0.0.0-20160126235308-23def4e6c14b", +) + +go_repository( + name = "com_github_google_go_cmp", + importpath = "github.com/google/go-cmp", + sum = "h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_google_go_github_v28", + importpath = "github.com/google/go-github/v28", + sum = "h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU=", + version = "v28.1.2-0.20191108005307-e555eab49ce8", +) + +go_repository( + name = "com_github_google_go_querystring", + importpath = "github.com/google/go-querystring", + sum = "h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_golang_lru", + importpath = "github.com/hashicorp/golang-lru", + sum = "h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU=", + version = "v0.5.1", +) + +go_repository( + name = "com_github_hpcloud_tail", + importpath = "github.com/hpcloud/tail", + sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_inconshreveable_mousetrap", + importpath = "github.com/inconshreveable/mousetrap", + sum = "h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_kisielk_errcheck", + importpath = "github.com/kisielk/errcheck", + sum = "h1:reN85Pxc5larApoH1keMBiu2GWtPqXQ1nc9gx+jOU+E=", + version = "v1.2.0", +) + +go_repository( + name = "com_github_kisielk_gotool", + importpath = "github.com/kisielk/gotool", + sum = "h1:AV2c/EiW3KqPNT9ZKl07ehoAGi4C5/01Cfbblndcapg=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_microsoft_hcsshim", + importpath = "github.com/Microsoft/hcsshim", + sum = "h1:ZfF0+zZeYdzMIVMZHKtDKJvLHj76XCuVae/jNkjj0IA=", + version = "v0.8.6", +) + +go_repository( + name = "com_github_onsi_ginkgo", + importpath = "github.com/onsi/ginkgo", + sum = "h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=", + version = "v1.10.1", +) + +go_repository( + name = "com_github_onsi_gomega", + importpath = "github.com/onsi/gomega", + sum = "h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=", + version = "v1.7.0", +) + +go_repository( + name = "com_github_opencontainers_runc", + importpath = "github.com/opencontainers/runc", + sum = "h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=", + version = "v0.1.1", +) + +go_repository( + name = "com_github_opencontainers_runtime_spec", + importpath = "github.com/opencontainers/runtime-spec", + sum = "h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8=", + version = "v1.0.2-0.20181111125026-1722abf79c2f", +) + +go_repository( + name = "com_github_pborman_uuid", + importpath = "github.com/pborman/uuid", + sum = "h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g=", + version = "v1.2.0", +) + +go_repository( + name = "com_github_prometheus_client_model", + importpath = "github.com/prometheus/client_model", + sum = "h1:gQz4mCbXsO+nc9n1hCxHcGA3Zx3Eo+UHZoInFGUIXNM=", + version = "v0.0.0-20190812154241-14fe0d1b01d4", +) + +go_repository( + name = "com_github_prometheus_procfs", + importpath = "github.com/prometheus/procfs", + sum = "h1:Lo6mRUjdS99f3zxYOUalftWHUoOGaDRqFk1+j0Q57/I=", + version = "v0.0.0-20190522114515-bc1a522cf7b1", +) + +go_repository( + name = "com_github_spf13_cobra", + importpath = "github.com/spf13/cobra", + sum = "h1:GQkkv3XSnxhAMjdq2wLfEnptEVr+2BNvmHizILHn+d4=", + version = "v0.0.2-0.20171109065643-2da4a54c5cee", +) + +go_repository( + name = "com_github_spf13_pflag", + importpath = "github.com/spf13/pflag", + sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=", + version = "v1.0.1-0.20171106142849-4c012f6dcd95", +) + +go_repository( + name = "com_github_urfave_cli", + importpath = "github.com/urfave/cli", + sum = "h1:MCfT24H3f//U5+UCrZp1/riVO3B50BovxtDiNn0XKkk=", + version = "v0.0.0-20171014202726-7bc6a0acffa5", +) + +go_repository( + name = "com_github_yuin_goldmark", + importpath = "github.com/yuin/goldmark", + sum = "h1:5tjfNdR2ki3yYQ842+eX2sQHeiwpKJ0RnHO4IYOc4V8=", + version = "v1.1.32", +) + +go_repository( + name = "in_gopkg_airbrake_gobrake_v2", + importpath = "gopkg.in/airbrake/gobrake.v2", + sum = "h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=", + version = "v2.0.9", +) + +go_repository( + name = "in_gopkg_fsnotify_v1", + importpath = "gopkg.in/fsnotify.v1", + sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=", + version = "v1.4.7", +) + +go_repository( + name = "in_gopkg_gemnasium_logrus_airbrake_hook_v2", + importpath = "gopkg.in/gemnasium/logrus-airbrake-hook.v2", + sum = "h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0=", + version = "v2.1.2", +) + +go_repository( + name = "in_gopkg_tomb_v1", + importpath = "gopkg.in/tomb.v1", + sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=", + version = "v1.0.0-20141024135613-dd632973f1e7", +) + +go_repository( + name = "in_gopkg_yaml_v2", + importpath = "gopkg.in/yaml.v2", + sum = "h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=", + version = "v2.2.8", +) + +go_repository( + name = "org_bazil_fuse", + importpath = "bazil.org/fuse", + sum = "h1:SC+c6A1qTFstO9qmB86mPV2IpYme/2ZoEQ0hrP+wo+Q=", + version = "v0.0.0-20160811212531-371fbbdaa898", +) + +go_repository( + name = "org_golang_google_appengine", + importpath = "google.golang.org/appengine", + sum = "h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=", + version = "v1.6.5", +) + +go_repository( + name = "org_golang_google_genproto", + importpath = "google.golang.org/genproto", + sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", + version = "v0.0.0-20200117163144-32f20d992d24", +) + +go_repository( + name = "org_golang_google_protobuf", + importpath = "google.golang.org/protobuf", + sum = "h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=", + version = "v1.23.0", +) + +go_repository( + name = "org_golang_x_exp", + importpath = "golang.org/x/exp", + sum = "h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg=", + version = "v0.0.0-20191227195350-da58074b4299", +) + +go_repository( + name = "org_golang_x_lint", + importpath = "golang.org/x/lint", + sum = "h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE=", + version = "v0.0.0-20191125180803-fdd1cda4f05f", +) + +go_repository( + name = "tools_gotest", + importpath = "gotest.tools", + sum = "h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=", + version = "v2.2.0+incompatible", +) + +go_repository( + name = "com_github_burntsushi_xgb", + importpath = "github.com/BurntSushi/xgb", + sum = "h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc=", + version = "v0.0.0-20160522181843-27f122750802", +) + +go_repository( + name = "com_github_chzyer_logex", + importpath = "github.com/chzyer/logex", + sum = "h1:Swpa1K6QvQznwJRcfTfQJmTE72DqScAa40E+fbHEXEE=", + version = "v1.1.10", +) + +go_repository( + name = "com_github_chzyer_readline", + importpath = "github.com/chzyer/readline", + sum = "h1:fY5BOSpyZCqRo5OhCuC+XN+r/bBCmeuuJtjz+bCNIf8=", + version = "v0.0.0-20180603132655-2972be24d48e", +) + +go_repository( + name = "com_github_chzyer_test", + importpath = "github.com/chzyer/test", + sum = "h1:q763qf9huN11kDQavWsoZXJNW3xEE4JJyHa5Q25/sd8=", + version = "v0.0.0-20180213035817-a1ea475d72b1", +) + +go_repository( + name = "com_github_go_gl_glfw_v3_3_glfw", + importpath = "github.com/go-gl/glfw/v3.3/glfw", + sum = "h1:b+9H1GAsx5RsjvDFLoS5zkNBzIQMuVKUYQDmxU3N5XE=", + version = "v0.0.0-20191125211704-12ad95a8df72", +) + +go_repository( + name = "com_github_golang_groupcache", + importpath = "github.com/golang/groupcache", + sum = "h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=", + version = "v0.0.0-20191227052852-215e87163ea7", +) + +go_repository( + name = "com_github_google_martian", + importpath = "github.com/google/martian", + sum = "h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no=", + version = "v2.1.0+incompatible", +) + +go_repository( + name = "com_github_google_pprof", + importpath = "github.com/google/pprof", + sum = "h1:DLpL8pWq0v4JYoRpEhDfsJhhJyGKCcQM2WPW2TJs31c=", + version = "v0.0.0-20191218002539-d4f498aebedc", +) + +go_repository( + name = "com_github_google_renameio", + importpath = "github.com/google/renameio", + sum = "h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA=", + version = "v0.1.0", ) go_repository( @@ -407,46 +972,127 @@ go_repository( ) go_repository( - name = "io_opencensus_go", - importpath = "go.opencensus.io", - sum = "h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8=", - version = "v0.22.3", + name = "com_github_ianlancetaylor_demangle", + importpath = "github.com/ianlancetaylor/demangle", + sum = "h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c=", + version = "v0.0.0-20181102032728-5e5cf60278f6", ) go_repository( - name = "com_github_golang_groupcache", - importpath = "github.com/golang/groupcache", - sum = "h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY=", - version = "v0.0.0-20200121045136-8c9f03a8e57e", + name = "com_github_jstemmer_go_junit_report", + importpath = "github.com/jstemmer/go-junit-report", + sum = "h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o=", + version = "v0.9.1", ) -# System Call test dependencies. -http_archive( - name = "com_google_absl", - sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93", - strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e", - urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz", - ], +go_repository( + name = "com_github_rogpeppe_go_internal", + importpath = "github.com/rogpeppe/go-internal", + sum = "h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk=", + version = "v1.3.0", ) -http_archive( - name = "com_google_googletest", - sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912", - strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34", - urls = [ - "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz", - ], +go_repository( + name = "com_shuralyov_dmitri_gpu_mtl", + importpath = "dmitri.shuralyov.com/gpu/mtl", + sum = "h1:VpgP7xuJadIUuKccphEpTJnWhS2jkQyMt6Y7pJCD7fY=", + version = "v0.0.0-20190408044501-666a987793e9", ) -http_archive( - name = "com_google_benchmark", - sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a", - strip_prefix = "benchmark-1.5.0", - urls = [ - "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz", - "https://github.com/google/benchmark/archive/v1.5.0.tar.gz", - ], +go_repository( + name = "in_gopkg_errgo_v2", + importpath = "gopkg.in/errgo.v2", + sum = "h1:0vLT13EuvQ0hNvakwLuFZ/jYrLp5F3kcWHXdRggjCE8=", + version = "v2.1.0", +) + +go_repository( + name = "io_rsc_binaryregexp", + importpath = "rsc.io/binaryregexp", + sum = "h1:HfqmD5MEmC0zvwBuF187nq9mdnXjXsSivRiXN7SmRkE=", + version = "v0.2.0", +) + +go_repository( + name = "org_golang_google_api", + importpath = "google.golang.org/api", + sum = "h1:yzlyyDW/J0w8yNFJIhiAJy4kq74S+1DOLdawELNxFMA=", + version = "v0.15.0", +) + +go_repository( + name = "org_golang_x_image", + importpath = "golang.org/x/image", + sum = "h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=", + version = "v0.0.0-20190802002840-cff245a6509b", +) + +go_repository( + name = "org_golang_x_mobile", + importpath = "golang.org/x/mobile", + sum = "h1:4+4C/Iv2U4fMZBiMCc98MG1In4gJY5YRhtpDNeDeHWs=", + version = "v0.0.0-20190719004257-d2bd2a29d028", +) + +go_repository( + name = "com_github_containerd_typeurl", + importpath = "github.com/containerd/typeurl", + sum = "h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o=", + version = "v0.0.0-20200205145503-b45ef1f1f737", +) + +go_repository( + name = "com_github_vishvananda_netns", + importpath = "github.com/vishvananda/netns", + sum = "h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so=", + version = "v0.0.0-20200520041808-52d707b772fe", +) + +go_repository( + name = "com_google_cloud_go_bigquery", + importpath = "cloud.google.com/go/bigquery", + sum = "h1:hL+ycaJpVE9M7nLoiXb/Pn10ENE2u+oddxbD8uu0ZVU=", + version = "v1.0.1", +) + +go_repository( + name = "com_google_cloud_go_datastore", + importpath = "cloud.google.com/go/datastore", + sum = "h1:Kt+gOPPp2LEPWp8CSfxhsM8ik9CcyE/gYu+0r+RnZvM=", + version = "v1.0.0", +) + +go_repository( + name = "com_google_cloud_go_pubsub", + importpath = "cloud.google.com/go/pubsub", + sum = "h1:W9tAK3E57P75u0XLLR82LZyw8VpAnhmyTOxW9qzmyj8=", + version = "v1.0.1", +) + +go_repository( + name = "com_google_cloud_go_storage", + importpath = "cloud.google.com/go/storage", + sum = "h1:VV2nUM3wwLLGh9lSABFgZMjInyUbJeaRSE64WuAIQ+4=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_errwrap", + importpath = "github.com/hashicorp/errwrap", + sum = "h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_hashicorp_go_multierror", + importpath = "github.com/hashicorp/go-multierror", + sum = "h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_dpjacques_clockwork", + importpath = "github.com/dpjacques/clockwork", + sum = "h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U=", + version = "v0.1.1-0.20190114191937-d864eecc357b", ) diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD index 48c548d59..2090d957a 100644 --- a/benchmarks/harness/BUILD +++ b/benchmarks/harness/BUILD @@ -10,7 +10,6 @@ pkg_tar( srcs = [ "//tools/installers:head", "//tools/installers:master", - "//tools/installers:runsc", ], mode = "0755", ) diff --git a/benchmarks/tcp/tcp_proxy.go b/benchmarks/tcp/tcp_proxy.go index b3a4dbea3..4b7ca7a14 100644 --- a/benchmarks/tcp/tcp_proxy.go +++ b/benchmarks/tcp/tcp_proxy.go @@ -228,7 +228,7 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(*sack)); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) } diff --git a/benchmarks/workloads/httpd/BUILD b/benchmarks/workloads/httpd/BUILD index a70873065..83450d190 100644 --- a/benchmarks/workloads/httpd/BUILD +++ b/benchmarks/workloads/httpd/BUILD @@ -9,5 +9,6 @@ pkg_tar( name = "tar", srcs = [ "Dockerfile", + "apache2-tmpdir.conf", ], ) diff --git a/benchmarks/workloads/httpd/Dockerfile b/benchmarks/workloads/httpd/Dockerfile index 5259c8f4f..52a550678 100644 --- a/benchmarks/workloads/httpd/Dockerfile +++ b/benchmarks/workloads/httpd/Dockerfile @@ -6,16 +6,16 @@ RUN set -x \ apache2 \ && rm -rf /var/lib/apt/lists/* -# Link the htdoc directory to tmp. -RUN mkdir -p /usr/local/apache2/htdocs && \ - cd /usr/local/apache2 && ln -s /tmp htdocs - # Generate a bunch of relevant files. RUN mkdir -p /local && \ for size in 1 10 100 1000 1024 10240; do \ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ done +# Rewrite DocumentRoot to point to /tmp/html instead of the default path. +RUN sed -i 's/DocumentRoot.*\/var\/www\/html$/DocumentRoot \/tmp\/html/' /etc/apache2/sites-enabled/000-default.conf +COPY ./apache2-tmpdir.conf /etc/apache2/sites-enabled/apache2-tmpdir.conf + # Standard settings. ENV APACHE_RUN_DIR /tmp ENV APACHE_RUN_USER nobody @@ -24,4 +24,4 @@ ENV APACHE_LOG_DIR /tmp ENV APACHE_PID_FILE /tmp/apache.pid # Copy on start-up; serve everything from /tmp (including the configuration). -CMD ["sh", "-c", "cp -a /local/* /tmp && apache2 -c \"ServerName localhost\" -c \"DocumentRoot /tmp\" -X"] +CMD ["sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && apache2 -X"] diff --git a/benchmarks/workloads/httpd/apache2-tmpdir.conf b/benchmarks/workloads/httpd/apache2-tmpdir.conf new file mode 100644 index 000000000..e33f8d9bb --- /dev/null +++ b/benchmarks/workloads/httpd/apache2-tmpdir.conf @@ -0,0 +1,5 @@ +<Directory /tmp/html/> + Options Indexes FollowSymLinks + AllowOverride None + Require all granted +</Directory>
\ No newline at end of file diff --git a/benchmarks/workloads/tensorflow/Dockerfile b/benchmarks/workloads/tensorflow/Dockerfile index b5763e8ae..eefe6b3eb 100644 --- a/benchmarks/workloads/tensorflow/Dockerfile +++ b/benchmarks/workloads/tensorflow/Dockerfile @@ -3,8 +3,8 @@ FROM tensorflow/tensorflow:1.13.2 RUN apt-get update \ && apt-get install -y git RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git -RUN python -m pip install -U pip setuptools -RUN python -m pip install matplotlib +RUN python -m pip install --no-cache-dir -U pip setuptools +RUN python -m pip install --no-cache-dir matplotlib WORKDIR /TensorFlow-Examples/examples diff --git a/g3doc/BUILD b/g3doc/BUILD index dbbf96204..c315d38be 100644 --- a/g3doc/BUILD +++ b/g3doc/BUILD @@ -33,3 +33,12 @@ doc( subcategory = "Community", weight = "95", ) + +doc( + name = "style", + src = "style.md", + category = "Project", + permalink = "/community/style_guide/", + subcategory = "Community", + weight = "10", +) diff --git a/g3doc/README.md b/g3doc/README.md index 304a91493..22bfb15f7 100644 --- a/g3doc/README.md +++ b/g3doc/README.md @@ -117,9 +117,7 @@ for more information on filesystem bundles. `runsc` implements multiple commands that perform various functions such as starting, stopping, listing, and querying the status of containers. -### Sentry - -<a name="sentry"></a> <!-- For deep linking. --> +### Sentry {#sentry} The Sentry is the largest component of gVisor. It can be thought of as a application kernel. The Sentry implements all the kernel functionality needed by @@ -136,9 +134,7 @@ calls it makes. For example, the Sentry is not able to open files directly; file system operations that extend beyond the sandbox (not internal `/proc` files, pipes, etc) are sent to the Gofer, described below. -### Gofer - -<a name="gofer"></a> <!-- For deep linking. --> +### Gofer {#gofer} The Gofer is a standard host process which is started with each container and communicates with the Sentry via the [9P protocol][9p] over a socket or shared @@ -146,13 +142,13 @@ memory channel. The Sentry process is started in a restricted seccomp container without access to file system resources. The Gofer mediates all access to the these resources, providing an additional level of isolation. -### Application +### Application {#application} The application is a normal Linux binary provided to gVisor in an OCI runtime bundle. gVisor aims to provide an environment equivalent to Linux v4.4, so applications should be able to run unmodified. However, gVisor does not presently implement every system call, `/proc` file, or `/sys` file so some -incompatibilities may occur. See [Commpatibility](./user_guide/compatibility.md) +incompatibilities may occur. See [Compatibility](./user_guide/compatibility.md) for more information. [9p]: https://en.wikipedia.org/wiki/9P_(protocol) diff --git a/g3doc/style.md b/g3doc/style.md new file mode 100644 index 000000000..d10549fe9 --- /dev/null +++ b/g3doc/style.md @@ -0,0 +1,88 @@ +# Provisional style guide + +> These guidelines are new and may change. This note will be removed when +> consensus is reached. + +Not all existing code will comply with this style guide, but new code should. +Further, it is a goal to eventually update all existing code to be in +compliance. + +## All code + +### Early exit + +All code, unless it substantially increases the line count or complexity, should +use early exits from loops and functions where possible. + +## Go specific + +All Go code should comply with the [Go Code Review Comments][gostyle] and +[Effective Go][effective_go] guides, as well as the additional guidelines +described below. + +### Mutexes + +#### Naming + +Mutexes should be named mu or xxxMu. Mutexes as a general rule should not be +exported. Instead, export methods which use the mutexes to avoid leaky +abstractions. + +#### Location + +Mutexes should be sibling fields to the fields that they protect. Mutexes should +not be declared as global variables, instead use a struct (anonymous ok, but +naming conventions still apply). + +Mutexes should be ordered before the fields that they protect. + +#### Comments + +Mutexes should have a comment on their declaration explaining any ordering +requirements (or pointing to where this information can be found), if +applicable. There is no need for a comment explaining which fields are +protected. + +Each field or variable protected by a mutex should state as such in a comment on +the field or variable declaration. + +### Unused returns + +Unused returns should be explicitly ignored with underscores. If there is a +function which is commonly used without using its return(s), a wrapper function +should be declared which explicitly ignores the returns. That said, in many +cases, it may make sense for the wrapper to check the returns. + +### Formatting verbs + +Built-in types should use their associated verbs (e.g. %d for integral types), +but other types should use a %v variant, even if they implement fmt.Stringer. +The built-in `error` type should use %w when formatted with `fmt.Errorf`, but +only then. + +### Wrapping + +Comments should be wrapped at 80 columns with a 2 space tab size. + +Code does not need to be wrapped, but if wrapping would make it more readable, +it should be wrapped with each subcomponent of the thing being wrapped on its +own line. For example, if a struct is split between lines, each field should be +on its own line. + +#### Example + +```go +_ = exec.Cmd{ + Path: "/foo/bar", + Args: []string{"-baz"}, +} +``` + +## C++ specific + +C++ code should conform to the [Google C++ Style Guide][cppstyle] and the +guidelines described for tests. + +[cppstyle]: https://google.github.io/styleguide/cppguide.html +[gostyle]: https://github.com/golang/go/wiki/CodeReviewComments +[effective_go]: https://golang.org/doc/effective_go.html diff --git a/g3doc/user_guide/BUILD b/g3doc/user_guide/BUILD index 5568e1ba4..355dd49b3 100644 --- a/g3doc/user_guide/BUILD +++ b/g3doc/user_guide/BUILD @@ -33,7 +33,7 @@ doc( name = "FAQ", src = "FAQ.md", category = "User Guide", - permalink = "/docs/user_guide/FAQ/", + permalink = "/docs/user_guide/faq/", weight = "90", ) @@ -68,3 +68,12 @@ doc( permalink = "/docs/user_guide/platforms/", weight = "30", ) + +doc( + name = "runtimeclass", + src = "runtimeclass.md", + category = "User Guide", + permalink = "/docs/user_guide/runtimeclass/", + subcategory = "Advanced", + weight = "91", +) diff --git a/g3doc/user_guide/debugging.md b/g3doc/user_guide/debugging.md index 0525fd5c0..54fdce34f 100644 --- a/g3doc/user_guide/debugging.md +++ b/g3doc/user_guide/debugging.md @@ -129,3 +129,13 @@ go tool pprof -top /usr/local/bin/runsc /tmp/cpu.prof ``` [pprof]: https://github.com/google/pprof/blob/master/doc/README.md + +### Docker Proxy + +When forwarding a port to the container, Docker will likely route traffic +through the [docker-proxy][]. This proxy may make profiling noisy, so it can be +helpful to bypass it. Do so by sending traffic directly to the container IP and +port. e.g., if the `docker0` IP is `192.168.9.1`, the container IP is likely a +subsequent IP, such as `192.168.9.2`. + +[docker-proxy]: https://windsock.io/the-docker-proxy/ diff --git a/g3doc/user_guide/quick_start/oci.md b/g3doc/user_guide/quick_start/oci.md index 877169145..e7768946b 100644 --- a/g3doc/user_guide/quick_start/oci.md +++ b/g3doc/user_guide/quick_start/oci.md @@ -15,8 +15,8 @@ mkdir bundle cd bundle ``` -Create a root file system for the container. We will use the Docker hello-world -image as the basis for our container. +Create a root file system for the container. We will use the Docker +`hello-world` image as the basis for our container. ```bash mkdir rootfs @@ -24,12 +24,10 @@ docker export $(docker create hello-world) | tar -xf - -C rootfs ``` Next, create an specification file called `config.json` that contains our -container specification. We will update the default command it runs to `/hello` -in the `hello-world` container. +container specification. We tell the container to run the `/hello` program. ```bash -runsc spec -sed -i 's;"sh";"/hello";' config.json +runsc spec -- /hello ``` Finally run the container. diff --git a/g3doc/user_guide/runtimeclass.md b/g3doc/user_guide/runtimeclass.md new file mode 100644 index 000000000..2e2d997be --- /dev/null +++ b/g3doc/user_guide/runtimeclass.md @@ -0,0 +1,46 @@ +# RuntimeClass + +First, follow the appropriate installation instructions for your version of +containerd. + +* For 1.1 or lower, use `gvisor-containerd-shim`. +* For 1.2 or higher, use `containerd-shim-runsc-v1`. + +# Set up the Kubernetes RuntimeClass + +Creating the [RuntimeClass][runtimeclass] in Kubernetes is simple once the +runtime is available for containerd: + +```shell +cat <<EOF | kubectl apply -f - +apiVersion: node.k8s.io/v1beta1 +kind: RuntimeClass +metadata: + name: gvisor +handler: runsc +EOF +``` + +Pods can now be created using this RuntimeClass: + +```shell +cat <<EOF | kubectl apply -f - +apiVersion: v1 +kind: Pod +metadata: + name: nginx-gvisor +spec: + runtimeClassName: gvisor + containers: + - name: nginx + image: nginx +EOF +``` + +You can verify that the Pod is running via this RuntimeClass: + +```shell +kubectl get pod nginx-gvisor -o wide +``` + +[runtimeclass]: https://kubernetes.io/docs/concepts/containers/runtime-class/ diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md index ad6c9fa59..ce2fd09a8 100644 --- a/g3doc/user_guide/tutorials/cni.md +++ b/g3doc/user_guide/tutorials/cni.md @@ -128,12 +128,14 @@ sudo mkdir -p rootfs/var/www/html sudo sh -c 'echo "Hello World!" > rootfs/var/www/html/index.html' ``` -Next create the `config.json` specifying the network namespace. `sudo -/usr/local/bin/runsc spec sudo sed -i 's;"sh";"python", "-m", "http.server";' -config.json sudo sed -i "s;\"cwd\": \"/\";\"cwd\": \"/var/www/html\";" -config.json sudo sed -i "s;\"type\": \"network\";\"type\": -\"network\",\n\t\t\t\t\"path\": \"/var/run/netns/${CNI_CONTAINERID}\";" -config.json` +Next create the `config.json` specifying the network namespace. + +``` +sudo /usr/local/bin/runsc spec \ + --cwd /var/www/html \ + --netns /var/run/netns/${CNI_CONTAINERID} \ + -- python -m http.server +``` ## Run the Container @@ -2,19 +2,51 @@ module gvisor.dev/gvisor go 1.14 +replace github.com/Sirupsen/logrus => github.com/sirupsen/logrus v1.6.0 + require ( - github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 - github.com/golang/protobuf v1.3.1 - github.com/google/btree v1.0.0 - github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 - github.com/kr/pretty v0.2.0 // indirect - github.com/kr/pty v1.1.1 - github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 - github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 - github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e - github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 // indirect - golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 - golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 // indirect + github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 // indirect + github.com/Microsoft/hcsshim v0.8.6 // indirect + github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 // indirect + github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect + github.com/containerd/containerd v1.3.4 // indirect + github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect + github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect + github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect + github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect + github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 // indirect + github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf // indirect + github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible // indirect + github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect + github.com/docker/go-connections v0.3.0 // indirect + github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect + github.com/docker/go-units v0.4.0 // indirect + github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b // indirect + github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect + github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect + github.com/gogo/googleapis v1.4.0 // indirect + github.com/golang/protobuf v1.4.2 // indirect + github.com/google/go-cmp v0.5.0 // indirect + github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect + github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect + github.com/hashicorp/go-multierror v1.0.0 // indirect + github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect + github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.0.1 // indirect + github.com/opencontainers/runc v0.1.1 // indirect + github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect + github.com/pborman/uuid v1.2.0 // indirect + github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect + github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5 // indirect + github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect + github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe // indirect + go.uber.org/atomic v1.6.0 // indirect + go.uber.org/multierr v1.2.0 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect + golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a // indirect + google.golang.org/grpc v1.29.0 // indirect + gopkg.in/yaml.v2 v2.2.8 // indirect + gotest.tools v2.2.0+incompatible // indirect ) @@ -1,32 +1,387 @@ -github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422 h1:+FKjzBIdfBHYDvxCv+djmDJdes/AoDtg8gpcxowBlF8= -github.com/cenkalti/backoff v0.0.0-20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726 h1:Fvo/6MiAbwmQpsq5YFRo8O6TC40m9MK4Xh/oN07rIlo= +cloud.google.com/go v0.52.1-0.20200122224058-0482b626c726/go.mod h1:pXajvRH/6o3+F9jDHZWQ5PbGhn+o8w9qiu/CffaVdO4= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/Microsoft/go-winio v0.4.14 h1:+hMXMk01us9KgxGb7ftKQt2Xpf5hH/yky+TDA+qxleU= +github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= +github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5 h1:ygIc8M6trr62pF5DucadTWGdEB4mEyvzi0e2nbcmcyA= +github.com/Microsoft/go-winio v0.4.15-0.20190919025122-fc70bd9a86b5/go.mod h1:tTuCMEN+UleMWgg9dVx4Hu52b1bJo+59jBh3ajtinzw= +github.com/Microsoft/hcsshim v0.8.6/go.mod h1:Op3hHsoHPAvb6lceZHDtd9OkTew38wNoXnJs8iY7rUg= +github.com/Microsoft/hcsshim v0.8.7/go.mod h1:OHd7sQqRFrYd3RmSgbgji+ctCwkbq2wbEYNSzOYtcBQ= +github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= +github.com/Microsoft/hcsshim v0.8.9 h1:VrfodqvztU8YSOvygU+DN1BGaSGxmrNfqOv5oOuX2Bk= +github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= +github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= +github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 h1:5yg0k8gqOssNLsjjCtXIADoPbAtUtQZJfC8hQ4r2oFY= +github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41/go.mod h1:X9rLEHIqSf/wfK8NsPqxJmeZgW4pcfzdXITDrUSJ6uI= +github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f h1:tSNMc+rJDfmYntojat8lljbt1mgKNpTxUZJsSzJ9Y1s= +github.com/containerd/cgroups v0.0.0-20190919134610-bf292b21730f/go.mod h1:OApqhQ4XNSNC13gXIwDjhOQxjWa/NxkwZXJ1EvqT0ko= +github.com/containerd/console v0.0.0-20180822173158-c12b1e7919c1/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= +github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e h1:GdiIYd8ZDOrT++e1NjhSD4rGt9zaJukHm4rt5F4mRQc= +github.com/containerd/console v0.0.0-20191206165004-02ecf6a7291e/go.mod h1:8Pf4gM6VEbTNRIT26AyyU7hxdQU3MvAvxVI0sc00XBE= +github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= +github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI= +github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= +github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= +github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k= +github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe/go.mod h1:cECdGN1O8G9bgKTlLhuPJimka6Xb/Gg7vYzCTNVxhvo= +github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI= +github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw= +github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0= +github.com/containerd/go-runc v0.0.0-20180907222934-5a6d9f37cfa3/go.mod h1:IV7qH3hrUgRmyYrtgEeGWJfWbgcHL9CSRruz2Vqcph0= +github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 h1:PRTagVMbJcCezLcHXe8UJvR1oBzp2lG3CEumeFOLOds= +github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328/go.mod h1:PpyHrqVs8FTi9vpyHwPwiNEGaACDxT/N/pLcvMSRA9g= +github.com/containerd/ttrpc v0.0.0-20190828154514-0e0f228740de/go.mod h1:PvCDdDGpgqzQIzDW1TphrGLssLDZp2GuS+X5DkEJB8o= +github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 h1:+jgiLE5QylzgADj0Yldb4id1NQNRrDOROj7KDvY9PEc= +github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15/go.mod h1:UAxOpgT9ziI0gJrmKvgcZivgxOp8iFPSk8httJEt98Y= +github.com/containerd/typeurl v0.0.0-20180627222232-a93fcdb778cd/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= +github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737 h1:HovfQDS/K3Mr7eyS0QJLxE1CbVUhjZCl6g3OhFJgP1o= +github.com/containerd/typeurl v0.0.0-20200205145503-b45ef1f1f737/go.mod h1:TB1hUtrpaiO88KEK56ijojHS1+NeF0izUACaJW2mdXg= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= +github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= +github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w= +github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.3.0 h1:3lOnM9cSzgGwx8VfK/NGOW5fLQ0GjIlCkaktF+n1M6o= +github.com/docker/go-connections v0.3.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ+oDZB4KHQFypsfjYlq/C4rfL7D3g8= +github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= +github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= +github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b h1:7krODee+eIlZYoLiEDmP1kLFNCvd0bQ0eEXOympdN6U= +github.com/dpjacques/clockwork v0.1.1-0.20190114191937-d864eecc357b/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= +github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= +github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= -github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= +github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= +github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= +github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8 h1:GZGUPQiZfYrd9uOqyqwbQcHPkz/EZJVkZB1MkaO9UBI= -github.com/google/subcommands v0.0.0-20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs= -github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= +github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= +github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM= +github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= +github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= +github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= +github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= +github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78 h1:d9F+LNYwMyi3BDN4GzZdaSiq4otb8duVEWyZjeUtOQI= -github.com/opencontainers/runtime-spec v0.1.2-0.20171211145439-b2d941ef6a78/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= +github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= +github.com/opencontainers/go-digest v1.0.0-rc1/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI= +github.com/opencontainers/image-spec v1.0.1/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= +github.com/opencontainers/runc v0.0.0-20190115041553-12f6a991201f/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= +github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y= +github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= +github.com/opencontainers/runtime-spec v0.1.2-0.20190507144316-5b71a03e2700/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.0.1/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f h1:Pyp2f/uuhJIcUgnIeZaAbwOcyNz8TBlEe6mPpC8kXq8= +github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= -github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e h1:/Tdc23Arz1OtdIsBY2utWepGRQ9fEAJlhkdoLzWMK8Q= -github.com/vishvananda/netlink v1.0.1-0.20190318003149-adb577d4a45e/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= -github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936 h1:J9gO8RJCAFlln1jsvRba/CWVUnMHwObklfxxjErl1uk= -github.com/vishvananda/netns v0.0.0-20171111001504-be1fbeda1936/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= -golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= +github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= +github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 h1:7SWt9pGCMaw+N1ZhRsaLKaYNviFhxambdoaoYlDqz1w= +github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk= +github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe h1:mjAZxE1nh8yvuwhGHpdDqdhtNu2dgbpk93TwoXuk5so= +github.com/vishvananda/netns v0.0.0-20200520041808-52d707b772fe/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= +go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299 h1:zQpM52jfKHG6II1ISZY1ZcpygvuSFZpLwfluuF89XOg= +golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344 h1:vGXIOMxbNfDTk/aXCmfdLgkrSV+Z2tcbze+pEc3v5W4= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= +golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190514135907-3a4b5fb9f71f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200523222454-059865788121 h1:rITEj+UZHYC927n8GT97eC3zrpzXdb/voyeOuVKS46o= +golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a h1:YAl/dx/kLsMMIWGqfhFHW9ckqGhmq7Ki0dfoKAgvFTE= +golang.org/x/tools v0.0.0-20200707200213-416e8f4faf8a/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.15.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= +google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/images/Makefile b/images/Makefile index 1485607bd..9de359a28 100644 --- a/images/Makefile +++ b/images/Makefile @@ -34,8 +34,15 @@ list-all-images: @for image in $(ALL_IMAGES); do echo $${image}; done .PHONY: list-build-images +# Handy wrapper to allow load-all-images, push-all-images, etc. %-all-images: @$(MAKE) $(patsubst %,$*-%,$(ALL_IMAGES)) +load-all-images: + @$(MAKE) $(patsubst %,load-%,$(ALL_IMAGES)) + +# Handy wrapper to load specified "groups", e.g. load-basic-images, etc. +load-%-images: + @$(MAKE) $(patsubst %,load-%,$(subst /,_,$(subst ./,,$(shell find ./$* -name Dockerfile -exec dirname {} \;)))) # tag is a function that returns the tag name, given an image. # diff --git a/images/README.md b/images/README.md index d2efb5db4..63ce46277 100644 --- a/images/README.md +++ b/images/README.md @@ -59,3 +59,12 @@ project. The continuous integration system can either take fine-grained dependencies on individual `push` targets, or ensure all images are up-to-date with a single `push-all-images` invocation. + +## Multi-Arch images + +By default, the image is built for host architecture. Cross-building can be +achieved by specifying `ARCH` variable to make. For example: + +``` +$ make ARCH=aarch64 rebuild-default +``` diff --git a/images/basic/hostoverlaytest/Dockerfile b/images/basic/hostoverlaytest/Dockerfile new file mode 100644 index 000000000..6cef1a542 --- /dev/null +++ b/images/basic/hostoverlaytest/Dockerfile @@ -0,0 +1,8 @@ +FROM ubuntu:bionic + +WORKDIR /root +COPY . . + +RUN apt-get update && apt-get install -y gcc +RUN gcc -O2 -o test_copy_up test_copy_up.c +RUN gcc -O2 -o test_rewinddir test_rewinddir.c diff --git a/images/hostoverlaytest/testfile.txt b/images/basic/hostoverlaytest/copy_up_testfile.txt index e4188c841..e4188c841 100644 --- a/images/hostoverlaytest/testfile.txt +++ b/images/basic/hostoverlaytest/copy_up_testfile.txt diff --git a/images/hostoverlaytest/test.c b/images/basic/hostoverlaytest/test_copy_up.c index 088f90746..010b261dc 100644 --- a/images/hostoverlaytest/test.c +++ b/images/basic/hostoverlaytest/test_copy_up.c @@ -6,7 +6,7 @@ #include <unistd.h> int main(int argc, char** argv) { - const char kTestFilePath[] = "testfile.txt"; + const char kTestFilePath[] = "copy_up_testfile.txt"; const char kOldFileData[] = "old data\n"; const char kNewFileData[] = "new data\n"; const size_t kPageSize = sysconf(_SC_PAGE_SIZE); diff --git a/images/basic/hostoverlaytest/test_rewinddir.c b/images/basic/hostoverlaytest/test_rewinddir.c new file mode 100644 index 000000000..f1a4085e1 --- /dev/null +++ b/images/basic/hostoverlaytest/test_rewinddir.c @@ -0,0 +1,78 @@ +#include <dirent.h> +#include <err.h> +#include <errno.h> +#include <stdlib.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> + +int main(int argc, char** argv) { + const char kDirPath[] = "rewinddir_test_dir"; + const char kFileBasename[] = "rewinddir_test_file"; + + // Create the test directory. + if (mkdir(kDirPath, 0755) < 0) { + err(1, "mkdir(%s)", kDirPath); + } + + // The test directory should initially be empty. + DIR* dir = opendir(kDirPath); + if (!dir) { + err(1, "opendir(%s)", kDirPath); + } + int failed = 0; + while (1) { + errno = 0; + struct dirent* d = readdir(dir); + if (!d) { + if (errno != 0) { + err(1, "readdir"); + } + break; + } + if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) { + warnx("unexpected file %s in new directory", d->d_name); + failed = 1; + } + } + + // Create a file in the test directory. + char* file_path = malloc(strlen(kDirPath) + 1 + strlen(kFileBasename)); + if (!file_path) { + errx(1, "malloc"); + } + strcpy(file_path, kDirPath); + file_path[strlen(kDirPath)] = '/'; + strcpy(file_path + strlen(kDirPath) + 1, kFileBasename); + if (mknod(file_path, 0644, 0) < 0) { + err(1, "mknod(%s)", file_path); + } + + // After rewinddir(), re-reading the directory stream should yield the new + // file. + rewinddir(dir); + size_t found_file = 0; + while (1) { + errno = 0; + struct dirent* d = readdir(dir); + if (!d) { + if (errno != 0) { + err(1, "readdir"); + } + break; + } + if (strcmp(d->d_name, kFileBasename) == 0) { + found_file++; + } else if (strcmp(d->d_name, ".") != 0 && strcmp(d->d_name, "..") != 0) { + warnx("unexpected file %s in new directory", d->d_name); + failed = 1; + } + } + if (found_file != 1) { + warnx("readdir returned file %s %zu times, wanted 1", kFileBasename, + found_file); + failed = 1; + } + + return failed; +} diff --git a/images/tmpfile/Dockerfile b/images/basic/tmpfile/Dockerfile index e3816c8cb..e3816c8cb 100644 --- a/images/tmpfile/Dockerfile +++ b/images/basic/tmpfile/Dockerfile diff --git a/images/benchmarks/ab/Dockerfile b/images/benchmarks/ab/Dockerfile new file mode 100644 index 000000000..10544639b --- /dev/null +++ b/images/benchmarks/ab/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2-utils \ + && rm -rf /var/lib/apt/lists/* diff --git a/images/benchmarks/absl/Dockerfile b/images/benchmarks/absl/Dockerfile new file mode 100644 index 000000000..b0dd97695 --- /dev/null +++ b/images/benchmarks/absl/Dockerfile @@ -0,0 +1,21 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + git \ + pkg-config \ + zip \ + g++ \ + zlib1g-dev \ + unzip \ + python3 \ + && rm -rf /var/lib/apt/lists/* +RUN wget https://github.com/bazelbuild/bazel/releases/download/0.27.0/bazel-0.27.0-installer-linux-x86_64.sh +RUN chmod +x bazel-0.27.0-installer-linux-x86_64.sh +RUN ./bazel-0.27.0-installer-linux-x86_64.sh + +RUN mkdir abseil-cpp && cd abseil-cpp \ + && git init && git remote add origin https://github.com/abseil/abseil-cpp.git \ + && git fetch --depth 1 origin 43ef2148c0936ebf7cb4be6b19927a9d9d145b8f && git checkout FETCH_HEAD diff --git a/images/benchmarks/ffmpeg/Dockerfile b/images/benchmarks/ffmpeg/Dockerfile new file mode 100644 index 000000000..7108df64f --- /dev/null +++ b/images/benchmarks/ffmpeg/Dockerfile @@ -0,0 +1,9 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* +WORKDIR /media +ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4 diff --git a/images/benchmarks/fio/Dockerfile b/images/benchmarks/fio/Dockerfile new file mode 100644 index 000000000..9531df7fa --- /dev/null +++ b/images/benchmarks/fio/Dockerfile @@ -0,0 +1,7 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + fio \ + && rm -rf /var/lib/apt/lists/* diff --git a/images/benchmarks/hey/Dockerfile b/images/benchmarks/hey/Dockerfile new file mode 100644 index 000000000..f586978b6 --- /dev/null +++ b/images/benchmarks/hey/Dockerfile @@ -0,0 +1,12 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + && rm -rf /var/lib/apt/lists/* + +RUN wget https://storage.googleapis.com/hey-release/hey_linux_amd64 \ + && chmod 777 hey_linux_amd64 \ + && cp hey_linux_amd64 /bin/hey \ + && rm hey_linux_amd64 diff --git a/images/benchmarks/httpd/Dockerfile b/images/benchmarks/httpd/Dockerfile new file mode 100644 index 000000000..b72406012 --- /dev/null +++ b/images/benchmarks/httpd/Dockerfile @@ -0,0 +1,17 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + apache2 \ + && rm -rf /var/lib/apt/lists/* + +# Generate a bunch of relevant files. +RUN mkdir -p /local && \ + for size in 1 10 100 1000 1024 10240; do \ + dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ + done + +# Rewrite DocumentRoot to point to /tmp/html instead of the default path. +RUN sed -i 's/DocumentRoot.*\/var\/www\/html$/DocumentRoot \/tmp\/html/' /etc/apache2/sites-enabled/000-default.conf +COPY ./apache2-tmpdir.conf /etc/apache2/sites-enabled/apache2-tmpdir.conf diff --git a/images/benchmarks/httpd/apache2-tmpdir.conf b/images/benchmarks/httpd/apache2-tmpdir.conf new file mode 100644 index 000000000..e33f8d9bb --- /dev/null +++ b/images/benchmarks/httpd/apache2-tmpdir.conf @@ -0,0 +1,5 @@ +<Directory /tmp/html/> + Options Indexes FollowSymLinks + AllowOverride None + Require all granted +</Directory>
\ No newline at end of file diff --git a/images/benchmarks/iperf/Dockerfile b/images/benchmarks/iperf/Dockerfile new file mode 100644 index 000000000..4cbfd0d70 --- /dev/null +++ b/images/benchmarks/iperf/Dockerfile @@ -0,0 +1,8 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + iperf \ + && rm -rf /var/lib/apt/lists/* + diff --git a/images/benchmarks/node/Dockerfile b/images/benchmarks/node/Dockerfile new file mode 100644 index 000000000..bf45650a0 --- /dev/null +++ b/images/benchmarks/node/Dockerfile @@ -0,0 +1 @@ +FROM node:onbuild diff --git a/images/benchmarks/node/index.hbs b/images/benchmarks/node/index.hbs new file mode 100644 index 000000000..03feceb75 --- /dev/null +++ b/images/benchmarks/node/index.hbs @@ -0,0 +1,8 @@ +<!DOCTYPE html> +<html> +<body> + {{#each text}} + <p>{{this}}</p> + {{/each}} +</body> +</html> diff --git a/images/benchmarks/node/index.js b/images/benchmarks/node/index.js new file mode 100644 index 000000000..831015d18 --- /dev/null +++ b/images/benchmarks/node/index.js @@ -0,0 +1,42 @@ +const app = require('express')(); +const path = require('path'); +const redis = require('redis'); +const srs = require('secure-random-string'); + +// The hostname is the first argument. +const host_name = process.argv[2]; + +var client = redis.createClient({host: host_name, detect_buffers: true}); + +app.set('views', __dirname); +app.set('view engine', 'hbs'); + +app.get('/', (req, res) => { + var tmp = []; + /* Pull four random keys from the redis server. */ + for (i = 0; i < 4; i++) { + client.get(Math.floor(Math.random() * (100)), function(err, reply) { + tmp.push(reply.toString()); + }); + } + res.render('index', {text: tmp}); +}); + +/** + * Securely generate a random string. + * @param {number} len + * @return {string} + */ +function randomBody(len) { + return srs({alphanumeric: true, length: len}); +} + +/** Mutates one hundred keys randomly. */ +function generateText() { + for (i = 0; i < 100; i++) { + client.set(i, randomBody(1024)); + } +} + +generateText(); +app.listen(8080); diff --git a/images/benchmarks/node/package-lock.json b/images/benchmarks/node/package-lock.json new file mode 100644 index 000000000..580e68aa5 --- /dev/null +++ b/images/benchmarks/node/package-lock.json @@ -0,0 +1,486 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "lockfileVersion": 1, + "requires": true, + "dependencies": { + "accepts": { + "version": "1.3.5", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-1.3.5.tgz", + "integrity": "sha1-63d99gEXI6OxTopywIBcjoZ0a9I=", + "requires": { + "mime-types": "~2.1.18", + "negotiator": "0.6.1" + } + }, + "array-flatten": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/array-flatten/-/array-flatten-1.1.1.tgz", + "integrity": "sha1-ml9pkFGx5wczKPKgCJaLZOopVdI=" + }, + "async": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/async/-/async-2.6.2.tgz", + "integrity": "sha512-H1qVYh1MYhEEFLsP97cVKqCGo7KfCyTt6uEWqsTBr9SO84oK9Uwbyd/yCW+6rKJLHksBNUVWZDAjfS+Ccx0Bbg==", + "requires": { + "lodash": "^4.17.11" + } + }, + "body-parser": { + "version": "1.18.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.18.3.tgz", + "integrity": "sha1-WykhmP/dVTs6DyDe0FkrlWlVyLQ=", + "requires": { + "bytes": "3.0.0", + "content-type": "~1.0.4", + "debug": "2.6.9", + "depd": "~1.1.2", + "http-errors": "~1.6.3", + "iconv-lite": "0.4.23", + "on-finished": "~2.3.0", + "qs": "6.5.2", + "raw-body": "2.3.3", + "type-is": "~1.6.16" + } + }, + "bytes": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.0.0.tgz", + "integrity": "sha1-0ygVQE1olpn4Wk6k+odV3ROpYEg=" + }, + "commander": { + "version": "2.20.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.0.tgz", + "integrity": "sha512-7j2y+40w61zy6YC2iRNpUe/NwhNyoXrYpHMrSunaMG64nRnaf96zO/KMQR4OyN/UnE5KLyEBnKHd4aG3rskjpQ==", + "optional": true + }, + "content-disposition": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-0.5.2.tgz", + "integrity": "sha1-DPaLud318r55YcOoUXjLhdunjLQ=" + }, + "content-type": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.4.tgz", + "integrity": "sha512-hIP3EEPs8tB9AT1L+NUqtwOAps4mk2Zob89MWXMHjHWg9milF/j4osnnQLXBCBFBk/tvIG/tUc9mOUJiPBhPXA==" + }, + "cookie": { + "version": "0.3.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.3.1.tgz", + "integrity": "sha1-5+Ch+e9DtMi6klxcWpboBtFoc7s=" + }, + "cookie-signature": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.0.6.tgz", + "integrity": "sha1-4wOogrNCzD7oylE6eZmXNNqzriw=" + }, + "debug": { + "version": "2.6.9", + "resolved": "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz", + "integrity": "sha512-bC7ElrdJaJnPbAP+1EotYvqZsb3ecl5wi6Bfi6BJTUcNowp6cvspg0jXznRTKDjm/E7AdgFBVeAPVMNcKGsHMA==", + "requires": { + "ms": "2.0.0" + } + }, + "depd": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/depd/-/depd-1.1.2.tgz", + "integrity": "sha1-m81S4UwJd2PnSbJ0xDRu0uVgtak=" + }, + "destroy": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/destroy/-/destroy-1.0.4.tgz", + "integrity": "sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=" + }, + "double-ended-queue": { + "version": "2.1.0-0", + "resolved": "https://registry.npmjs.org/double-ended-queue/-/double-ended-queue-2.1.0-0.tgz", + "integrity": "sha1-ED01J/0xUo9AGIEwyEHv3XgmTlw=" + }, + "ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha1-WQxhFWsK4vTwJVcyoViyZrxWsh0=" + }, + "encodeurl": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-1.0.2.tgz", + "integrity": "sha1-rT/0yG7C0CkyL1oCw6mmBslbP1k=" + }, + "escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha1-Aljq5NPQwJdN4cFpGI7wBR0dGYg=" + }, + "etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha1-Qa4u62XvpiJorr/qg6x9eSmbCIc=" + }, + "express": { + "version": "4.16.4", + "resolved": "https://registry.npmjs.org/express/-/express-4.16.4.tgz", + "integrity": "sha512-j12Uuyb4FMrd/qQAm6uCHAkPtO8FDTRJZBDd5D2KOL2eLaz1yUNdUB/NOIyq0iU4q4cFarsUCrnFDPBcnksuOg==", + "requires": { + "accepts": "~1.3.5", + "array-flatten": "1.1.1", + "body-parser": "1.18.3", + "content-disposition": "0.5.2", + "content-type": "~1.0.4", + "cookie": "0.3.1", + "cookie-signature": "1.0.6", + "debug": "2.6.9", + "depd": "~1.1.2", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "finalhandler": "1.1.1", + "fresh": "0.5.2", + "merge-descriptors": "1.0.1", + "methods": "~1.1.2", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "path-to-regexp": "0.1.7", + "proxy-addr": "~2.0.4", + "qs": "6.5.2", + "range-parser": "~1.2.0", + "safe-buffer": "5.1.2", + "send": "0.16.2", + "serve-static": "1.13.2", + "setprototypeof": "1.1.0", + "statuses": "~1.4.0", + "type-is": "~1.6.16", + "utils-merge": "1.0.1", + "vary": "~1.1.2" + } + }, + "finalhandler": { + "version": "1.1.1", + "resolved": "http://registry.npmjs.org/finalhandler/-/finalhandler-1.1.1.tgz", + "integrity": "sha512-Y1GUDo39ez4aHAw7MysnUD5JzYX+WaIj8I57kO3aEPT1fFRL4sr7mjei97FgnwhAyyzRYmQZaTHb2+9uZ1dPtg==", + "requires": { + "debug": "2.6.9", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "on-finished": "~2.3.0", + "parseurl": "~1.3.2", + "statuses": "~1.4.0", + "unpipe": "~1.0.0" + } + }, + "foreachasync": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/foreachasync/-/foreachasync-3.0.0.tgz", + "integrity": "sha1-VQKYfchxS+M5IJfzLgBxyd7gfPY=" + }, + "forwarded": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.1.2.tgz", + "integrity": "sha1-mMI9qxF1ZXuMBXPozszZGw/xjIQ=" + }, + "fresh": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-0.5.2.tgz", + "integrity": "sha1-PYyt2Q2XZWn6g1qx+OSyOhBWBac=" + }, + "handlebars": { + "version": "4.0.14", + "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.0.14.tgz", + "integrity": "sha512-E7tDoyAA8ilZIV3xDJgl18sX3M8xB9/fMw8+mfW4msLW8jlX97bAnWgT3pmaNXuvzIEgSBMnAHfuXsB2hdzfow==", + "requires": { + "async": "^2.5.0", + "optimist": "^0.6.1", + "source-map": "^0.6.1", + "uglify-js": "^3.1.4" + } + }, + "hbs": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/hbs/-/hbs-4.0.4.tgz", + "integrity": "sha512-esVlyV/V59mKkwFai5YmPRSNIWZzhqL5YMN0++ueMxyK1cCfPa5f6JiHtapPKAIVAhQR6rpGxow0troav9WMEg==", + "requires": { + "handlebars": "4.0.14", + "walk": "2.3.9" + } + }, + "http-errors": { + "version": "1.6.3", + "resolved": "http://registry.npmjs.org/http-errors/-/http-errors-1.6.3.tgz", + "integrity": "sha1-i1VoC7S+KDoLW/TqLjhYC+HZMg0=", + "requires": { + "depd": "~1.1.2", + "inherits": "2.0.3", + "setprototypeof": "1.1.0", + "statuses": ">= 1.4.0 < 2" + } + }, + "iconv-lite": { + "version": "0.4.23", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.23.tgz", + "integrity": "sha512-neyTUVFtahjf0mB3dZT77u+8O0QB89jFdnBkd5P1JgYPbPaia3gXXOVL2fq8VyU2gMMD7SaN7QukTB/pmXYvDA==", + "requires": { + "safer-buffer": ">= 2.1.2 < 3" + } + }, + "inherits": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.3.tgz", + "integrity": "sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=" + }, + "ipaddr.js": { + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.8.0.tgz", + "integrity": "sha1-6qM9bd16zo9/b+DJygRA5wZzix4=" + }, + "lodash": { + "version": "4.17.15", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.15.tgz", + "integrity": "sha512-8xOcRHvCjnocdS5cpwXQXVzmmh5e5+saE2QGoeQmbKmRS6J3VQppPOIt0MnmE+4xlZoumy0GPG0D0MVIQbNA1A==" + }, + "media-typer": { + "version": "0.3.0", + "resolved": "http://registry.npmjs.org/media-typer/-/media-typer-0.3.0.tgz", + "integrity": "sha1-hxDXrwqmJvj/+hzgAWhUUmMlV0g=" + }, + "merge-descriptors": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", + "integrity": "sha1-sAqqVW3YtEVoFQ7J0blT8/kMu2E=" + }, + "methods": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/methods/-/methods-1.1.2.tgz", + "integrity": "sha1-VSmk1nZUE07cxSZmVoNbD4Ua/O4=" + }, + "mime": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/mime/-/mime-1.4.1.tgz", + "integrity": "sha512-KI1+qOZu5DcW6wayYHSzR/tXKCDC5Om4s1z2QJjDULzLcmf3DvzS7oluY4HCTrc+9FiKmWUgeNLg7W3uIQvxtQ==" + }, + "mime-db": { + "version": "1.37.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.37.0.tgz", + "integrity": "sha512-R3C4db6bgQhlIhPU48fUtdVmKnflq+hRdad7IyKhtFj06VPNVdk2RhiYL3UjQIlso8L+YxAtFkobT0VK+S/ybg==" + }, + "mime-types": { + "version": "2.1.21", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.21.tgz", + "integrity": "sha512-3iL6DbwpyLzjR3xHSFNFeb9Nz/M8WDkX33t1GFQnFOllWk8pOrh/LSrB5OXlnlW5P9LH73X6loW/eogc+F5lJg==", + "requires": { + "mime-db": "~1.37.0" + } + }, + "minimist": { + "version": "0.0.10", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", + "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" + }, + "ms": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", + "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" + }, + "negotiator": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.1.tgz", + "integrity": "sha1-KzJxhOiZIQEXeyhWP7XnECrNDKk=" + }, + "on-finished": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.3.0.tgz", + "integrity": "sha1-IPEzZIGwg811M3mSoWlxqi2QaUc=", + "requires": { + "ee-first": "1.1.1" + } + }, + "optimist": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/optimist/-/optimist-0.6.1.tgz", + "integrity": "sha1-2j6nRob6IaGaERwybpDrFaAZZoY=", + "requires": { + "minimist": "~0.0.1", + "wordwrap": "~0.0.2" + } + }, + "parseurl": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.2.tgz", + "integrity": "sha1-/CidTtiZMRlGDBViUyYs3I3mW/M=" + }, + "path-to-regexp": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", + "integrity": "sha1-32BBeABfUi8V60SQ5yR6G/qmf4w=" + }, + "proxy-addr": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.4.tgz", + "integrity": "sha512-5erio2h9jp5CHGwcybmxmVqHmnCBZeewlfJ0pex+UW7Qny7OOZXTtH56TGNyBizkgiOwhJtMKrVzDTeKcySZwA==", + "requires": { + "forwarded": "~0.1.2", + "ipaddr.js": "1.8.0" + } + }, + "qs": { + "version": "6.5.2", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.5.2.tgz", + "integrity": "sha512-N5ZAX4/LxJmF+7wN74pUD6qAh9/wnvdQcjq9TZjevvXzSUo7bfmw91saqMjzGS2xq91/odN2dW/WOl7qQHNDGA==" + }, + "range-parser": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.0.tgz", + "integrity": "sha1-9JvmtIeJTdxA3MlKMi9hEJLgDV4=" + }, + "raw-body": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.3.3.tgz", + "integrity": "sha512-9esiElv1BrZoI3rCDuOuKCBRbuApGGaDPQfjSflGxdy4oyzqghxu6klEkkVIvBje+FF0BX9coEv8KqW6X/7njw==", + "requires": { + "bytes": "3.0.0", + "http-errors": "1.6.3", + "iconv-lite": "0.4.23", + "unpipe": "1.0.0" + } + }, + "redis": { + "version": "2.8.0", + "resolved": "https://registry.npmjs.org/redis/-/redis-2.8.0.tgz", + "integrity": "sha512-M1OkonEQwtRmZv4tEWF2VgpG0JWJ8Fv1PhlgT5+B+uNq2cA3Rt1Yt/ryoR+vQNOQcIEgdCdfH0jr3bDpihAw1A==", + "requires": { + "double-ended-queue": "^2.1.0-0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0" + }, + "dependencies": { + "redis-commands": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.4.0.tgz", + "integrity": "sha512-cu8EF+MtkwI4DLIT0x9P8qNTLFhQD4jLfxLR0cCNkeGzs87FN6879JOJwNQR/1zD7aSYNbU0hgsV9zGY71Itvw==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + } + } + }, + "redis-commands": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/redis-commands/-/redis-commands-1.5.0.tgz", + "integrity": "sha512-6KxamqpZ468MeQC3bkWmCB1fp56XL64D4Kf0zJSwDZbVLLm7KFkoIcHrgRvQ+sk8dnhySs7+yBg94yIkAK7aJg==" + }, + "redis-parser": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/redis-parser/-/redis-parser-2.6.0.tgz", + "integrity": "sha1-Uu0J2srBCPGmMcB+m2mUHnoZUEs=" + }, + "safe-buffer": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.2.tgz", + "integrity": "sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==" + }, + "safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==" + }, + "secure-random-string": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/secure-random-string/-/secure-random-string-1.1.0.tgz", + "integrity": "sha512-V/h8jqoz58zklNGybVhP++cWrxEPXlLM/6BeJ4e0a8zlb4BsbYRzFs16snrxByPa5LUxCVTD3M6EYIVIHR1fAg==" + }, + "send": { + "version": "0.16.2", + "resolved": "https://registry.npmjs.org/send/-/send-0.16.2.tgz", + "integrity": "sha512-E64YFPUssFHEFBvpbbjr44NCLtI1AohxQ8ZSiJjQLskAdKuriYEP6VyGEsRDH8ScozGpkaX1BGvhanqCwkcEZw==", + "requires": { + "debug": "2.6.9", + "depd": "~1.1.2", + "destroy": "~1.0.4", + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "etag": "~1.8.1", + "fresh": "0.5.2", + "http-errors": "~1.6.2", + "mime": "1.4.1", + "ms": "2.0.0", + "on-finished": "~2.3.0", + "range-parser": "~1.2.0", + "statuses": "~1.4.0" + } + }, + "serve-static": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.13.2.tgz", + "integrity": "sha512-p/tdJrO4U387R9oMjb1oj7qSMaMfmOyd4j9hOFoxZe2baQszgHcSWjuya/CiT5kgZZKRudHNOA0pYXOl8rQ5nw==", + "requires": { + "encodeurl": "~1.0.2", + "escape-html": "~1.0.3", + "parseurl": "~1.3.2", + "send": "0.16.2" + } + }, + "setprototypeof": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.1.0.tgz", + "integrity": "sha512-BvE/TwpZX4FXExxOxZyRGQQv651MSwmWKZGqvmPcRIjDqWub67kTKuIMx43cZZrS/cBBzwBcNDWoFxt2XEFIpQ==" + }, + "source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==" + }, + "statuses": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-1.4.0.tgz", + "integrity": "sha512-zhSCtt8v2NDrRlPQpCNtw/heZLtfUDqxBM1udqikb/Hbk52LK4nQSwr10u77iopCW5LsyHpuXS0GnEc48mLeew==" + }, + "type-is": { + "version": "1.6.16", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.16.tgz", + "integrity": "sha512-HRkVv/5qY2G6I8iab9cI7v1bOIdhm94dVjQCPFElW9W+3GeDOSHmy2EBYe4VTApuzolPcmgFTN3ftVJRKR2J9Q==", + "requires": { + "media-typer": "0.3.0", + "mime-types": "~2.1.18" + } + }, + "uglify-js": { + "version": "3.5.9", + "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.5.9.tgz", + "integrity": "sha512-WpT0RqsDtAWPNJK955DEnb6xjymR8Fn0OlK4TT4pS0ASYsVPqr5ELhgwOwLCP5J5vHeJ4xmMmz3DEgdqC10JeQ==", + "optional": true, + "requires": { + "commander": "~2.20.0", + "source-map": "~0.6.1" + } + }, + "unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha1-sr9O6FFKrmFltIF4KdIbLvSZBOw=" + }, + "utils-merge": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/utils-merge/-/utils-merge-1.0.1.tgz", + "integrity": "sha1-n5VxD1CiZ5R7LMwSR0HBAoQn5xM=" + }, + "vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha1-IpnwLG3tMNSllhsLn3RSShj2NPw=" + }, + "walk": { + "version": "2.3.9", + "resolved": "https://registry.npmjs.org/walk/-/walk-2.3.9.tgz", + "integrity": "sha1-MbTbZnjyrgHDnqn7hyWpAx5Vins=", + "requires": { + "foreachasync": "^3.0.0" + } + }, + "wordwrap": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", + "integrity": "sha1-o9XabNXAvAAI03I0u68b7WMFkQc=" + } + } +} diff --git a/images/benchmarks/node/package.json b/images/benchmarks/node/package.json new file mode 100644 index 000000000..7dcadd523 --- /dev/null +++ b/images/benchmarks/node/package.json @@ -0,0 +1,19 @@ +{ + "name": "nodedum", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "dependencies": { + "express": "^4.16.4", + "hbs": "^4.0.4", + "redis": "^2.8.0", + "redis-commands": "^1.2.0", + "redis-parser": "^2.6.0", + "secure-random-string": "^1.1.0" + } +} diff --git a/images/benchmarks/redis/Dockerfile b/images/benchmarks/redis/Dockerfile new file mode 100644 index 000000000..0f17249af --- /dev/null +++ b/images/benchmarks/redis/Dockerfile @@ -0,0 +1 @@ +FROM redis:5.0.4 diff --git a/images/benchmarks/runsc/Dockerfile b/images/benchmarks/runsc/Dockerfile new file mode 100644 index 000000000..6c3aafa57 --- /dev/null +++ b/images/benchmarks/runsc/Dockerfile @@ -0,0 +1,24 @@ +FROM ubuntu:18.04 + +RUN set -x \ + && apt-get update \ + && apt-get install -y \ + wget \ + git \ + pkg-config \ + zip \ + g++ \ + zlib1g-dev \ + unzip \ + python-minimal \ + python3 \ + python3-pip \ + && rm -rf /var/lib/apt/lists/* +RUN wget https://github.com/bazelbuild/bazel/releases/download/3.4.1/bazel-3.4.1-installer-linux-x86_64.sh +RUN chmod +x bazel-3.4.1-installer-linux-x86_64.sh +RUN ./bazel-3.4.1-installer-linux-x86_64.sh + +# Download release-20200601.0 +RUN mkdir gvisor && cd gvisor \ + && git init && git remote add origin https://github.com/google/gvisor.git \ + && git fetch --depth 1 origin a9b47390c821942d60784e308f681f213645049c && git checkout FETCH_HEAD diff --git a/images/benchmarks/tensorflow/Dockerfile b/images/benchmarks/tensorflow/Dockerfile new file mode 100644 index 000000000..7564a4ee5 --- /dev/null +++ b/images/benchmarks/tensorflow/Dockerfile @@ -0,0 +1,7 @@ +FROM tensorflow/tensorflow:1.13.2 + +RUN apt-get update \ + && apt-get install -y git +RUN git clone --depth 1 https://github.com/aymericdamien/TensorFlow-Examples.git +RUN python -m pip install -U pip setuptools +RUN python -m pip install matplotlib diff --git a/images/default/Dockerfile b/images/default/Dockerfile index 397082b02..d058b83cb 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -1,8 +1,8 @@ FROM fedora:31 # Install bazel. RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel -RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch -RUN pip install pycparser +RUN dnf install -y git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static libstdc++-static patch diffutils +RUN pip install --no-cache-dir pycparser RUN dnf install -y bazel3 # Install gcloud. RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \ diff --git a/images/hostoverlaytest/Dockerfile b/images/hostoverlaytest/Dockerfile deleted file mode 100644 index d83439e9c..000000000 --- a/images/hostoverlaytest/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM ubuntu:bionic - -WORKDIR /root -COPY . . - -RUN apt-get update && apt-get install -y gcc -RUN gcc -O2 -o test test.c diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile index 4860dd750..ba039ba15 100644 --- a/images/jekyll/Dockerfile +++ b/images/jekyll/Dockerfile @@ -10,4 +10,5 @@ RUN gem install \ jekyll-relative-links:0.6.1 \ jekyll-feed:0.13.0 \ jekyll-sitemap:1.4.0 +COPY checks.rb /checks.rb CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"] diff --git a/images/jekyll/checks.rb b/images/jekyll/checks.rb new file mode 100644 index 000000000..fc7e6b5a8 --- /dev/null +++ b/images/jekyll/checks.rb @@ -0,0 +1,36 @@ +#!/usr/local/bin/ruby +# +# HTMLProofer checks for the gVisor website. +# +require 'html-proofer' + +# NoOpenerCheck checks to make sure links with target=_blank include the +# rel=noopener attribute. +class NoOpenerCheck < ::HTMLProofer::Check + def run + @html.css('a').each do |node| + link = create_element(node) + line = node.line + + rel = link.respond_to?(:rel) ? link.rel.split(' ') : [] + + if link.respond_to?(:target) && link.target == "_blank" && !rel.include?("noopener") + return add_issue("You should set rel=noopener for links with target=_blank", line: line) + end + end + end +end + +def main() + options = { + :check_html => true, + :check_favicon => true, + :disable_external => true, + } + + HTMLProofer.check_directories(ARGV, options).run +end + +if __FILE__ == $0 + main +end diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 114b516e2..05ca5342f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -23,11 +23,13 @@ go_library( "errors.go", "eventfd.go", "exec.go", + "fadvise.go", "fcntl.go", "file.go", "file_amd64.go", "file_arm64.go", "fs.go", + "fuse.go", "futex.go", "inotify.go", "ioctl.go", @@ -71,6 +73,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/usermem", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index fa3ae5f18..192e2093b 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -46,6 +46,10 @@ const ( // TTYAUX_MAJOR is the major device number for alternate TTY devices. TTYAUX_MAJOR = 5 + // MISC_MAJOR is the major device number for non-serial mice, misc feature + // devices. + MISC_MAJOR = 10 + // UNIX98_PTY_MASTER_MAJOR is the initial major device number for // Unix98 PTY masters. UNIX98_PTY_MASTER_MAJOR = 128 diff --git a/pkg/abi/linux/fadvise.go b/pkg/abi/linux/fadvise.go new file mode 100644 index 000000000..b06ff9964 --- /dev/null +++ b/pkg/abi/linux/fadvise.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +const ( + POSIX_FADV_NORMAL = 0 + POSIX_FADV_RANDOM = 1 + POSIX_FADV_SEQUENTIAL = 2 + POSIX_FADV_WILLNEED = 3 + POSIX_FADV_DONTNEED = 4 + POSIX_FADV_NOREUSE = 5 +) diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index 6663a199c..9242e80a5 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -55,7 +55,7 @@ type Flock struct { _ [4]byte } -// Flags for F_SETOWN_EX and F_GETOWN_EX. +// Owner types for F_SETOWN_EX and F_GETOWN_EX. const ( F_OWNER_TID = 0 F_OWNER_PID = 1 diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go index 055ac1d7c..e11ca2d62 100644 --- a/pkg/abi/linux/file.go +++ b/pkg/abi/linux/file.go @@ -191,8 +191,9 @@ var DirentType = abi.ValueSet{ // Values for preadv2/pwritev2. const ( - // Note: gVisor does not implement the RWF_HIPRI feature, but the flag is - // accepted as a valid flag argument for preadv2/pwritev2. + // NOTE(b/120162627): gVisor does not implement the RWF_HIPRI feature, but + // the flag is accepted as a valid flag argument for preadv2/pwritev2 and + // silently ignored. RWF_HIPRI = 0x00000001 RWF_DSYNC = 0x00000002 RWF_SYNC = 0x00000004 diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go new file mode 100644 index 000000000..5c6ffe4a3 --- /dev/null +++ b/pkg/abi/linux/fuse.go @@ -0,0 +1,248 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +// +marshal +type FUSEOpcode uint32 + +// +marshal +type FUSEOpID uint64 + +// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. +const ( + FUSE_LOOKUP FUSEOpcode = 1 + FUSE_FORGET = 2 /* no reply */ + FUSE_GETATTR = 3 + FUSE_SETATTR = 4 + FUSE_READLINK = 5 + FUSE_SYMLINK = 6 + _ + FUSE_MKNOD = 8 + FUSE_MKDIR = 9 + FUSE_UNLINK = 10 + FUSE_RMDIR = 11 + FUSE_RENAME = 12 + FUSE_LINK = 13 + FUSE_OPEN = 14 + FUSE_READ = 15 + FUSE_WRITE = 16 + FUSE_STATFS = 17 + FUSE_RELEASE = 18 + _ + FUSE_FSYNC = 20 + FUSE_SETXATTR = 21 + FUSE_GETXATTR = 22 + FUSE_LISTXATTR = 23 + FUSE_REMOVEXATTR = 24 + FUSE_FLUSH = 25 + FUSE_INIT = 26 + FUSE_OPENDIR = 27 + FUSE_READDIR = 28 + FUSE_RELEASEDIR = 29 + FUSE_FSYNCDIR = 30 + FUSE_GETLK = 31 + FUSE_SETLK = 32 + FUSE_SETLKW = 33 + FUSE_ACCESS = 34 + FUSE_CREATE = 35 + FUSE_INTERRUPT = 36 + FUSE_BMAP = 37 + FUSE_DESTROY = 38 + FUSE_IOCTL = 39 + FUSE_POLL = 40 + FUSE_NOTIFY_REPLY = 41 + FUSE_BATCH_FORGET = 42 +) + +const ( + // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem. + // This is the minimum size Linux supports. See linux.fuse.h. + FUSE_MIN_READ_BUFFER uint32 = 8192 +) + +// FUSEHeaderIn is the header read by the daemon with each request. +// +// +marshal +type FUSEHeaderIn struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Opcode specifies the kind of operation of the request. + Opcode FUSEOpcode + + // Unique specifies the unique identifier for this request. + Unique FUSEOpID + + // NodeID is the ID of the filesystem object being operated on. + NodeID uint64 + + // UID is the UID of the requesting process. + UID uint32 + + // GID is the GID of the requesting process. + GID uint32 + + // PID is the PID of the requesting process. + PID uint32 + + _ uint32 +} + +// FUSEHeaderOut is the header written by the daemon when it processes +// a request and wants to send a reply (almost all operations require a +// reply; if they do not, this will be explicitly documented). +// +// +marshal +type FUSEHeaderOut struct { + // Len specifies the total length of the data, including this header. + Len uint32 + + // Error specifies the error that occurred (0 if none). + Error int32 + + // Unique specifies the unique identifier of the corresponding request. + Unique FUSEOpID +} + +// FUSEWriteIn is the header written by a daemon when it makes a +// write request to the FUSE filesystem. +// +// +marshal +type FUSEWriteIn struct { + // Fh specifies the file handle that is being written to. + Fh uint64 + + // Offset is the offset of the write. + Offset uint64 + + // Size is the size of data being written. + Size uint32 + + // WriteFlags is the flags used during the write. + WriteFlags uint32 + + // LockOwner is the ID of the lock owner. + LockOwner uint64 + + // Flags is the flags for the request. + Flags uint32 + + _ uint32 +} + +// FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h. +const ( + FUSE_ASYNC_READ = 1 << 0 + FUSE_POSIX_LOCKS = 1 << 1 + FUSE_FILE_OPS = 1 << 2 + FUSE_ATOMIC_O_TRUNC = 1 << 3 + FUSE_EXPORT_SUPPORT = 1 << 4 + FUSE_BIG_WRITES = 1 << 5 + FUSE_DONT_MASK = 1 << 6 + FUSE_SPLICE_WRITE = 1 << 7 + FUSE_SPLICE_MOVE = 1 << 8 + FUSE_SPLICE_READ = 1 << 9 + FUSE_FLOCK_LOCKS = 1 << 10 + FUSE_HAS_IOCTL_DIR = 1 << 11 + FUSE_AUTO_INVAL_DATA = 1 << 12 + FUSE_DO_READDIRPLUS = 1 << 13 + FUSE_READDIRPLUS_AUTO = 1 << 14 + FUSE_ASYNC_DIO = 1 << 15 + FUSE_WRITEBACK_CACHE = 1 << 16 + FUSE_NO_OPEN_SUPPORT = 1 << 17 + FUSE_PARALLEL_DIROPS = 1 << 18 + FUSE_HANDLE_KILLPRIV = 1 << 19 + FUSE_POSIX_ACL = 1 << 20 + FUSE_ABORT_ERROR = 1 << 21 + FUSE_MAX_PAGES = 1 << 22 + FUSE_CACHE_SYMLINKS = 1 << 23 + FUSE_NO_OPENDIR_SUPPORT = 1 << 24 + FUSE_EXPLICIT_INVAL_DATA = 1 << 25 + FUSE_MAP_ALIGNMENT = 1 << 26 +) + +// currently supported FUSE protocol version numbers. +const ( + FUSE_KERNEL_VERSION = 7 + FUSE_KERNEL_MINOR_VERSION = 31 +) + +// FUSEInitIn is the request sent by the kernel to the daemon, +// to negotiate the version and flags. +// +// +marshal +type FUSEInitIn struct { + // Major version supported by kernel. + Major uint32 + + // Minor version supported by the kernel. + Minor uint32 + + // MaxReadahead is the maximum number of bytes to read-ahead + // decided by the kernel. + MaxReadahead uint32 + + // Flags of this init request. + Flags uint32 +} + +// FUSEInitOut is the reply sent by the daemon to the kernel +// for FUSEInitIn. +// +// +marshal +type FUSEInitOut struct { + // Major version supported by daemon. + Major uint32 + + // Minor version supported by daemon. + Minor uint32 + + // MaxReadahead is the maximum number of bytes to read-ahead. + // Decided by the daemon, after receiving the value from kernel. + MaxReadahead uint32 + + // Flags of this init reply. + Flags uint32 + + // MaxBackground is the maximum number of pending background requests + // that the daemon wants. + MaxBackground uint16 + + // CongestionThreshold is the daemon-decided threshold for + // the number of the pending background requests. + CongestionThreshold uint16 + + // MaxWrite is the daemon's maximum size of a write buffer. + // Kernel adjusts it to the minimum (fuse/init.go:fuseMinMaxWrite). + // if the value from daemon is too small. + MaxWrite uint32 + + // TimeGran is the daemon's time granularity for mtime and ctime metadata. + // The unit is nanosecond. + // Value should be power of 10. + // 1 indicates full nanosecond granularity support. + TimeGran uint32 + + // MaxPages is the daemon's maximum number of pages for one write operation. + // Kernel adjusts it to the maximum (fuse/init.go:FUSE_MAX_MAX_PAGES). + // if the value from daemon is too large. + MaxPages uint16 + + // MapAlignment is an unknown field and not used by this package at this moment. + // Use as a placeholder to be consistent with the FUSE protocol. + MapAlignment uint16 + + _ [8]uint32 +} diff --git a/pkg/abi/linux/futex.go b/pkg/abi/linux/futex.go index 08bfde3b5..8138088a6 100644 --- a/pkg/abi/linux/futex.go +++ b/pkg/abi/linux/futex.go @@ -60,3 +60,21 @@ const ( FUTEX_WAITERS = 0x80000000 FUTEX_OWNER_DIED = 0x40000000 ) + +// FUTEX_BITSET_MATCH_ANY has all bits set. +const FUTEX_BITSET_MATCH_ANY = 0xffffffff + +// ROBUST_LIST_LIMIT protects against a deliberately circular list. +const ROBUST_LIST_LIMIT = 2048 + +// RobustListHead corresponds to Linux's struct robust_list_head. +// +// +marshal +type RobustListHead struct { + List uint64 + FutexOffset uint64 + ListOpPending uint64 +} + +// SizeOfRobustListHead is the size of a RobustListHead struct. +var SizeOfRobustListHead = (*RobustListHead)(nil).SizeBytes() diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2062e6a4b..2c5e56ae5 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -67,10 +67,29 @@ const ( // ioctl(2) requests provided by uapi/linux/sockios.h const ( - SIOCGIFMEM = 0x891f - SIOCGIFPFLAGS = 0x8935 - SIOCGMIIPHY = 0x8947 - SIOCGMIIREG = 0x8948 + SIOCGIFNAME = 0x8910 + SIOCGIFCONF = 0x8912 + SIOCGIFFLAGS = 0x8913 + SIOCGIFADDR = 0x8915 + SIOCGIFDSTADDR = 0x8917 + SIOCGIFBRDADDR = 0x8919 + SIOCGIFNETMASK = 0x891b + SIOCGIFMETRIC = 0x891d + SIOCGIFMTU = 0x8921 + SIOCGIFMEM = 0x891f + SIOCGIFHWADDR = 0x8927 + SIOCGIFINDEX = 0x8933 + SIOCGIFPFLAGS = 0x8935 + SIOCGIFTXQLEN = 0x8942 + SIOCETHTOOL = 0x8946 + SIOCGMIIPHY = 0x8947 + SIOCGMIIREG = 0x8948 + SIOCGIFMAP = 0x8970 +) + +// ioctl(2) requests provided by uapi/asm-generic/sockios.h +const ( + SIOCGSTAMP = 0x8906 ) // ioctl(2) directions. Used to calculate requests number. diff --git a/pkg/abi/linux/netdevice.go b/pkg/abi/linux/netdevice.go index 7866352b4..0faf015c7 100644 --- a/pkg/abi/linux/netdevice.go +++ b/pkg/abi/linux/netdevice.go @@ -22,6 +22,8 @@ const ( ) // IFReq is an interface request. +// +// +marshal type IFReq struct { // IFName is an encoded name, normally null-terminated. This should be // accessed via the Name and SetName functions. @@ -79,6 +81,8 @@ type IFMap struct { // IFConf is used to return a list of interfaces and their addresses. See // netdevice(7) and struct ifconf for more detail on its use. +// +// +marshal type IFConf struct { Len int32 _ [4]byte // Pad to sizeof(struct ifconf). diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 46d8b0b42..a91f9f018 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -14,6 +14,14 @@ package linux +import ( + "io" + + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + // This file contains structures required to support netfilter, specifically // the iptables tool. @@ -76,6 +84,8 @@ const ( // IPTEntry is an iptable rule. It corresponds to struct ipt_entry in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTEntry struct { // IP is used to filter packets based on the IP header. IP IPTIP @@ -112,21 +122,41 @@ type IPTEntry struct { // SizeOfIPTEntry is the size of an IPTEntry. const SizeOfIPTEntry = 112 -// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. This -// struct marshaled via the binary package to write an IPTEntry to userspace. +// KernelIPTEntry is identical to IPTEntry, but includes the Elems field. +// KernelIPTEntry itself is not Marshallable but it implements some methods of +// marshal.Marshallable that help in other implementations of Marshallable. type KernelIPTEntry struct { - IPTEntry + Entry IPTEntry // Elems holds the data for all this rule's matches followed by the // target. It is variable length -- users have to iterate over any // matches and use TargetOffset and NextOffset to make sense of the // data. - Elems []byte + Elems primitive.ByteSlice +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTEntry) SizeBytes() int { + return ke.Entry.SizeBytes() + ke.Elems.SizeBytes() +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { + ke.Entry.MarshalBytes(dst) + ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { + ke.Entry.UnmarshalBytes(src) + ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) } // IPTIP contains information for matching a packet's IP header. // It corresponds to struct ipt_ip in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTIP struct { // Src is the source IP address. Src InetAddr @@ -189,6 +219,8 @@ const SizeOfIPTIP = 84 // XTCounters holds packet and byte counts for a rule. It corresponds to struct // xt_counters in include/uapi/linux/netfilter/x_tables.h. +// +// +marshal type XTCounters struct { // Pcnt is the packet count. Pcnt uint64 @@ -321,6 +353,8 @@ const SizeOfXTRedirectTarget = 56 // IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds // to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetinfo struct { Name TableName ValidHooks uint32 @@ -336,6 +370,8 @@ const SizeOfIPTGetinfo = 84 // IPTGetEntries is the argument for the IPT_SO_GET_ENTRIES sockopt. It // corresponds to struct ipt_get_entries in // include/uapi/linux/netfilter_ipv4/ip_tables.h. +// +// +marshal type IPTGetEntries struct { Name TableName Size uint32 @@ -350,13 +386,103 @@ type IPTGetEntries struct { const SizeOfIPTGetEntries = 40 // KernelIPTGetEntries is identical to IPTGetEntries, but includes the -// Entrytable field. This struct marshaled via the binary package to write an -// KernelIPTGetEntries to userspace. +// Entrytable field. This has been manually made marshal.Marshallable since it +// is dynamically sized. type KernelIPTGetEntries struct { IPTGetEntries Entrytable []KernelIPTEntry } +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (ke *KernelIPTGetEntries) SizeBytes() int { + res := ke.IPTGetEntries.SizeBytes() + for _, entry := range ke.Entrytable { + res += entry.SizeBytes() + } + return res +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { + ke.IPTGetEntries.MarshalBytes(dst) + marshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) + marshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { + ke.IPTGetEntries.UnmarshalBytes(src) + unmarshalledUntil := ke.IPTGetEntries.SizeBytes() + for i := 0; i < len(ke.Entrytable); i++ { + ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) + unmarshalledUntil += ke.Entrytable[i].SizeBytes() + } +} + +// Packed implements marshal.Marshallable.Packed. +func (ke *KernelIPTGetEntries) Packed() bool { + // KernelIPTGetEntries isn't packed because the ke.Entrytable contains an + // indirection to the actual data we want to marshal (the slice data + // pointer), and the memory for KernelIPTGetEntries contains the slice + // header which we don't want to marshal. + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (ke *KernelIPTGetEntries) MarshalUnsafe(dst []byte) { + // Fall back to safe Marshal because the type in not packed. + ke.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { + // Fall back to safe Unmarshal because the type in not packed. + ke.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := task.CopyInBytes(addr, buf) // escapes: okay. + // Unmarshal unconditionally. If we had a short copy-in, this results in a + // partially unmarshalled struct. + ke.UnmarshalBytes(buf) // escapes: fallback. + return length, err +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall + // back to MarshalBytes. + return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) +} + +func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { + buf := task.CopyScratchBuffer(ke.SizeBytes()) + ke.MarshalBytes(buf) + return buf +} + +// WriteTo implements io.WriterTo.WriteTo. +func (ke *KernelIPTGetEntries) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, ke.SizeBytes()) + ke.MarshalBytes(buf) + length, err := w.Write(buf) + return int64(length), err +} + +var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) + // IPTReplace is the argument for the IPT_SO_SET_REPLACE sockopt. It // corresponds to struct ipt_replace in // include/uapi/linux/netfilter_ipv4/ip_tables.h. @@ -374,12 +500,6 @@ type IPTReplace struct { // Entries [0]IPTEntry } -// KernelIPTReplace is identical to IPTReplace, but includes the Entries field. -type KernelIPTReplace struct { - IPTReplace - Entries [0]IPTEntry -} - // SizeOfIPTReplace is the size of an IPTReplace. const SizeOfIPTReplace = 96 @@ -392,6 +512,8 @@ func (en ExtensionName) String() string { } // TableName holds the name of a netfilter table. +// +// +marshal type TableName [XT_TABLE_MAXNAMELEN]byte // String implements fmt.Stringer. diff --git a/pkg/abi/linux/netlink_route.go b/pkg/abi/linux/netlink_route.go index 40bec566c..ceda0a8d3 100644 --- a/pkg/abi/linux/netlink_route.go +++ b/pkg/abi/linux/netlink_route.go @@ -187,6 +187,8 @@ const ( // Device types, from uapi/linux/if_arp.h. const ( + ARPHRD_NONE = 65534 + ARPHRD_ETHER = 1 ARPHRD_LOOPBACK = 772 ) diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 4a14ef691..c24a8216e 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -134,6 +134,15 @@ const ( SHUT_RDWR = 2 ) +// Packet types from <linux/if_packet.h> +const ( + PACKET_HOST = 0 // To us + PACKET_BROADCAST = 1 // To all + PACKET_MULTICAST = 2 // To group + PACKET_OTHERHOST = 3 // To someone else + PACKET_OUTGOING = 4 // Outgoing of any type +) + // Socket options from socket.h. const ( SO_DEBUG = 1 @@ -225,6 +234,8 @@ const ( const SockAddrMax = 128 // InetAddr is struct in_addr, from uapi/linux/in.h. +// +// +marshal type InetAddr [4]byte // SockAddrInet is struct sockaddr_in, from uapi/linux/in.h. @@ -294,6 +305,8 @@ func (s *SockAddrUnix) implementsSockAddr() {} func (s *SockAddrNetlink) implementsSockAddr() {} // Linger is struct linger, from include/linux/socket.h. +// +// +marshal type Linger struct { OnOff int32 Linger int32 @@ -308,6 +321,8 @@ const SizeOfLinger = 8 // the end of this struct or within existing unusued space, so its size grows // over time. The current iteration is based on linux v4.17. New versions are // always backwards compatible. +// +// +marshal type TCPInfo struct { State uint8 CaState uint8 @@ -405,6 +420,8 @@ var SizeOfControlMessageHeader = int(binary.Size(ControlMessageHeader{})) // A ControlMessageCredentials is an SCM_CREDENTIALS socket control message. // // ControlMessageCredentials represents struct ucred from linux/socket.h. +// +// +marshal type ControlMessageCredentials struct { PID int32 UID uint32 diff --git a/pkg/compressio/compressio.go b/pkg/compressio/compressio.go index 5f52cbe74..b094c5662 100644 --- a/pkg/compressio/compressio.go +++ b/pkg/compressio/compressio.go @@ -346,20 +346,22 @@ func (p *pool) schedule(c *chunk, callback func(*chunk) error) error { } } -// reader chunks reads and decompresses. -type reader struct { +// Reader is a compressed reader. +type Reader struct { pool // in is the source. in io.Reader } +var _ io.Reader = (*Reader)(nil) + // NewReader returns a new compressed reader. If key is non-nil, the data stream // is assumed to contain expected hash values, which will be compared against // hash values computed from the compressed bytes. See package comments for // details. -func NewReader(in io.Reader, key []byte) (io.Reader, error) { - r := &reader{ +func NewReader(in io.Reader, key []byte) (*Reader, error) { + r := &Reader{ in: in, } @@ -394,8 +396,19 @@ var errNewBuffer = errors.New("buffer ready") // ErrHashMismatch is returned if the hash does not match. var ErrHashMismatch = errors.New("hash mismatch") +// ReadByte implements wire.Reader.ReadByte. +func (r *Reader) ReadByte() (byte, error) { + var p [1]byte + n, err := r.Read(p[:]) + if n != 1 { + return p[0], err + } + // Suppress EOF. + return p[0], nil +} + // Read implements io.Reader.Read. -func (r *reader) Read(p []byte) (int, error) { +func (r *Reader) Read(p []byte) (int, error) { r.mu.Lock() defer r.mu.Unlock() @@ -551,8 +564,8 @@ func (r *reader) Read(p []byte) (int, error) { return done, nil } -// writer chunks and schedules writes. -type writer struct { +// Writer is a compressed writer. +type Writer struct { pool // out is the underlying writer. @@ -562,6 +575,8 @@ type writer struct { closed bool } +var _ io.Writer = (*Writer)(nil) + // NewWriter returns a new compressed writer. If key is non-nil, hash values are // generated and written out for compressed bytes. See package comments for // details. @@ -569,8 +584,8 @@ type writer struct { // The recommended chunkSize is on the order of 1M. Extra memory may be // buffered (in the form of read-ahead, or buffered writes), and is limited to // O(chunkSize * [1+GOMAXPROCS]). -func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.WriteCloser, error) { - w := &writer{ +func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) { + w := &Writer{ pool: pool{ chunkSize: chunkSize, buf: bufPool.Get().(*bytes.Buffer), @@ -597,7 +612,7 @@ func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (io.Write } // flush writes a single buffer. -func (w *writer) flush(c *chunk) error { +func (w *Writer) flush(c *chunk) error { // Prefix each chunk with a length; this allows the reader to safely // limit reads while buffering. l := uint32(c.compressed.Len()) @@ -624,8 +639,23 @@ func (w *writer) flush(c *chunk) error { return nil } +// WriteByte implements wire.Writer.WriteByte. +// +// Note that this implementation is necessary on the object itself, as an +// interface-based dispatch cannot tell whether the array backing the slice +// escapes, therefore the all bytes written will generate an escape. +func (w *Writer) WriteByte(b byte) error { + var p [1]byte + p[0] = b + n, err := w.Write(p[:]) + if n != 1 { + return err + } + return nil +} + // Write implements io.Writer.Write. -func (w *writer) Write(p []byte) (int, error) { +func (w *Writer) Write(p []byte) (int, error) { w.mu.Lock() defer w.mu.Unlock() @@ -710,7 +740,7 @@ func (w *writer) Write(p []byte) (int, error) { } // Close implements io.Closer.Close. -func (w *writer) Close() error { +func (w *Writer) Close() error { w.mu.Lock() defer w.mu.Unlock() diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD index 798a65eca..35683fe98 100644 --- a/pkg/gohacks/BUILD +++ b/pkg/gohacks/BUILD @@ -7,5 +7,6 @@ go_library( srcs = [ "gohacks_unsafe.go", ], + stateify = False, visibility = ["//:sandbox"], ) diff --git a/pkg/ilist/list.go b/pkg/ilist/list.go index 0d07da3b1..f4a4c33d3 100644 --- a/pkg/ilist/list.go +++ b/pkg/ilist/list.go @@ -90,7 +90,7 @@ func (l *List) Back() Element { // // NOTE: This is an O(n) operation. func (l *List) Len() (count int) { - for e := l.Front(); e != nil; e = e.Next() { + for e := l.Front(); e != nil; e = (ElementMapper{}.linkerFor(e)).Next() { count++ } return count @@ -182,13 +182,13 @@ func (l *List) Remove(e Element) { if prev != nil { ElementMapper{}.linkerFor(prev).SetNext(next) - } else { + } else if l.head == e { l.head = next } if next != nil { ElementMapper{}.linkerFor(next).SetPrev(prev) - } else { + } else if l.tail == e { l.tail = prev } diff --git a/pkg/iovec/BUILD b/pkg/iovec/BUILD new file mode 100644 index 000000000..eda82cfc1 --- /dev/null +++ b/pkg/iovec/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "iovec", + srcs = ["iovec.go"], + visibility = ["//:sandbox"], + deps = ["//pkg/abi/linux"], +) + +go_test( + name = "iovec_test", + size = "small", + srcs = ["iovec_test.go"], + library = ":iovec", + deps = ["@org_golang_x_sys//unix:go_default_library"], +) diff --git a/pkg/iovec/iovec.go b/pkg/iovec/iovec.go new file mode 100644 index 000000000..dd70fe80f --- /dev/null +++ b/pkg/iovec/iovec.go @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package iovec provides helpers to interact with vectorized I/O on host +// system. +package iovec + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +// MaxIovs is the maximum number of iovecs host platform can accept. +var MaxIovs = linux.UIO_MAXIOV + +// Builder is a builder for slice of syscall.Iovec. +type Builder struct { + iovec []syscall.Iovec + storage [8]syscall.Iovec + + // overflow tracks the last buffer when iovec length is at MaxIovs. + overflow []byte +} + +// Add adds buf to b preparing to be written. Zero-length buf won't be added. +func (b *Builder) Add(buf []byte) { + if len(buf) == 0 { + return + } + if b.iovec == nil { + b.iovec = b.storage[:0] + } + if len(b.iovec) >= MaxIovs { + b.addByAppend(buf) + return + } + b.iovec = append(b.iovec, syscall.Iovec{ + Base: &buf[0], + Len: uint64(len(buf)), + }) + // Keep the last buf if iovec is at max capacity. We will need to append to it + // for later bufs. + if len(b.iovec) == MaxIovs { + n := len(buf) + b.overflow = buf[:n:n] + } +} + +func (b *Builder) addByAppend(buf []byte) { + b.overflow = append(b.overflow, buf...) + b.iovec[len(b.iovec)-1] = syscall.Iovec{ + Base: &b.overflow[0], + Len: uint64(len(b.overflow)), + } +} + +// Build returns the final Iovec slice. The length of returned iovec will not +// excceed MaxIovs. +func (b *Builder) Build() []syscall.Iovec { + return b.iovec +} diff --git a/pkg/iovec/iovec_test.go b/pkg/iovec/iovec_test.go new file mode 100644 index 000000000..a3900c299 --- /dev/null +++ b/pkg/iovec/iovec_test.go @@ -0,0 +1,121 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package iovec + +import ( + "bytes" + "fmt" + "syscall" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func TestBuilderEmpty(t *testing.T) { + var builder Builder + iovecs := builder.Build() + if got, want := len(iovecs), 0; got != want { + t.Errorf("len(iovecs) = %d, want %d", got, want) + } +} + +func TestBuilderBuild(t *testing.T) { + a := []byte{1, 2} + b := []byte{3, 4, 5} + + var builder Builder + builder.Add(a) + builder.Add(b) + builder.Add(nil) // Nil slice won't be added. + builder.Add([]byte{}) // Empty slice won't be added. + iovecs := builder.Build() + + if got, want := len(iovecs), 2; got != want { + t.Fatalf("len(iovecs) = %d, want %d", got, want) + } + for i, data := range [][]byte{a, b} { + if got, want := *iovecs[i].Base, data[0]; got != want { + t.Fatalf("*iovecs[%d].Base = %d, want %d", i, got, want) + } + if got, want := iovecs[i].Len, uint64(len(data)); got != want { + t.Fatalf("iovecs[%d].Len = %d, want %d", i, got, want) + } + } +} + +func TestBuilderBuildMaxIov(t *testing.T) { + for _, test := range []struct { + numIov int + }{ + { + numIov: MaxIovs - 1, + }, + { + numIov: MaxIovs, + }, + { + numIov: MaxIovs + 1, + }, + { + numIov: MaxIovs + 10, + }, + } { + name := fmt.Sprintf("numIov=%v", test.numIov) + t.Run(name, func(t *testing.T) { + var data []byte + var builder Builder + for i := 0; i < test.numIov; i++ { + buf := []byte{byte(i)} + builder.Add(buf) + data = append(data, buf...) + } + iovec := builder.Build() + + // Check the expected length of iovec. + wantNum := test.numIov + if wantNum > MaxIovs { + wantNum = MaxIovs + } + if got, want := len(iovec), wantNum; got != want { + t.Errorf("len(iovec) = %d, want %d", got, want) + } + + // Test a real read-write. + var fds [2]int + if err := unix.Pipe(fds[:]); err != nil { + t.Fatalf("Pipe: %v", err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + wrote, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fds[1]), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if int(wrote) != len(data) || e != 0 { + t.Fatalf("writev: %v, %v; want %v, 0", wrote, e, len(data)) + } + + got := make([]byte, len(data)) + if n, err := syscall.Read(fds[0], got); n != len(got) || err != nil { + t.Fatalf("read: %v, %v; want %v, nil", n, err, len(got)) + } + + if !bytes.Equal(got, data) { + t.Errorf("read: got data %v, want %v", got, data) + } + }) + } +} diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 57b89ad7d..2cb59f934 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -2506,7 +2506,7 @@ type msgFactory struct { var msgRegistry registry type registry struct { - factories [math.MaxUint8]msgFactory + factories [math.MaxUint8 + 1]msgFactory // largestFixedSize is computed so that given some message size M, you can // compute the maximum payload size (e.g. for Twrite, Rread) with diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 28d851ff5..122c457d2 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -1091,6 +1091,19 @@ type AllocateMode struct { Unshare bool } +// ToAllocateMode returns an AllocateMode from a fallocate(2) mode. +func ToAllocateMode(mode uint64) AllocateMode { + return AllocateMode{ + KeepSize: mode&unix.FALLOC_FL_KEEP_SIZE != 0, + PunchHole: mode&unix.FALLOC_FL_PUNCH_HOLE != 0, + NoHideStale: mode&unix.FALLOC_FL_NO_HIDE_STALE != 0, + CollapseRange: mode&unix.FALLOC_FL_COLLAPSE_RANGE != 0, + ZeroRange: mode&unix.FALLOC_FL_ZERO_RANGE != 0, + InsertRange: mode&unix.FALLOC_FL_INSERT_RANGE != 0, + Unshare: mode&unix.FALLOC_FL_UNSHARE_RANGE != 0, + } +} + // ToLinux converts to a value compatible with fallocate(2)'s mode. func (a *AllocateMode) ToLinux() uint32 { rv := uint32(0) diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go index daba8b172..fd95eb2d2 100644 --- a/pkg/sentry/arch/arch_aarch64.go +++ b/pkg/sentry/arch/arch_aarch64.go @@ -28,7 +28,14 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs + + // TPIDR_EL0 is the EL0 Read/Write Software Thread ID Register. + TPIDR_EL0 uint64 +} const ( // SyscallWidth is the width of insturctions. @@ -101,9 +108,6 @@ type State struct { // Our floating point state. aarch64FPState `state:"wait"` - // TLS pointer - TPValue uint64 - // FeatureSet is a pointer to the currently active feature set. FeatureSet *cpuid.FeatureSet @@ -157,7 +161,6 @@ func (s *State) Fork() State { return State{ Regs: s.Regs, aarch64FPState: s.aarch64FPState.fork(), - TPValue: s.TPValue, FeatureSet: s.FeatureSet, OrigR0: s.OrigR0, } @@ -241,18 +244,18 @@ func (s *State) ptraceGetRegs() Registers { return s.Regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } regs.UnmarshalUnsafe(buf) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // PtraceGetFPRegs implements Context.PtraceGetFPRegs. @@ -278,7 +281,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -291,7 +294,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 3b3a0a272..1c3e3c14c 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -300,7 +300,7 @@ func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { // PTRACE_PEEKUSER and PTRACE_POKEUSER are only effective on regs and // u_debugreg, returning 0 or silently no-oping for other fields // respectively. - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) @@ -315,7 +315,7 @@ func (c *context64) PtracePokeUser(addr, data uintptr) error { if addr&7 != 0 || addr >= userStructSize { return syscall.EIO } - if addr < uintptr(registersSize) { + if addr < uintptr(ptraceRegistersSize) { regs := c.ptraceGetRegs() buf := make([]byte, regs.SizeBytes()) regs.MarshalUnsafe(buf) diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index ada7ac7b8..cabbf60e0 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -142,7 +142,7 @@ func (c *context64) SetStack(value uintptr) { // TLS returns the current TLS pointer. func (c *context64) TLS() uintptr { - return uintptr(c.TPValue) + return uintptr(c.Regs.TPIDR_EL0) } // SetTLS sets the current TLS pointer. Returns false if value is invalid. @@ -151,7 +151,7 @@ func (c *context64) SetTLS(value uintptr) bool { return false } - c.TPValue = uint64(value) + c.Regs.TPIDR_EL0 = uint64(value) return true } diff --git a/pkg/sentry/arch/arch_x86.go b/pkg/sentry/arch/arch_x86.go index dc458b37f..b9405b320 100644 --- a/pkg/sentry/arch/arch_x86.go +++ b/pkg/sentry/arch/arch_x86.go @@ -31,7 +31,11 @@ import ( ) // Registers represents the CPU registers for this architecture. -type Registers = linux.PtraceRegs +// +// +stateify savable +type Registers struct { + linux.PtraceRegs +} // System-related constants for x86. const ( @@ -311,12 +315,12 @@ func (s *State) ptraceGetRegs() Registers { return regs } -var registersSize = (*Registers)(nil).SizeBytes() +var ptraceRegistersSize = (*linux.PtraceRegs)(nil).SizeBytes() // PtraceSetRegs implements Context.PtraceSetRegs. func (s *State) PtraceSetRegs(src io.Reader) (int, error) { var regs Registers - buf := make([]byte, registersSize) + buf := make([]byte, ptraceRegistersSize) if _, err := io.ReadFull(src, buf); err != nil { return 0, err } @@ -374,7 +378,7 @@ func (s *State) PtraceSetRegs(src io.Reader) (int, error) { } regs.Eflags = (s.Regs.Eflags &^ eflagsPtraceMutable) | (regs.Eflags & eflagsPtraceMutable) s.Regs = regs - return registersSize, nil + return ptraceRegistersSize, nil } // isUserSegmentSelector returns true if the given segment selector specifies a @@ -543,7 +547,7 @@ const ( func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceGetRegs(dst) @@ -563,7 +567,7 @@ func (s *State) PtraceGetRegSet(regset uintptr, dst io.Writer, maxlen int) (int, func (s *State) PtraceSetRegSet(regset uintptr, src io.Reader, maxlen int) (int, error) { switch regset { case _NT_PRSTATUS: - if maxlen < registersSize { + if maxlen < ptraceRegistersSize { return 0, syserror.EFAULT } return s.PtraceSetRegs(src) diff --git a/pkg/sentry/devices/ttydev/BUILD b/pkg/sentry/devices/ttydev/BUILD new file mode 100644 index 000000000..12e49b58a --- /dev/null +++ b/pkg/sentry/devices/ttydev/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "ttydev", + srcs = ["ttydev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/vfs", + "//pkg/usermem", + ], +) diff --git a/pkg/sentry/devices/ttydev/ttydev.go b/pkg/sentry/devices/ttydev/ttydev.go new file mode 100644 index 000000000..fbb7fd92c --- /dev/null +++ b/pkg/sentry/devices/ttydev/ttydev.go @@ -0,0 +1,91 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ttydev implements devices for /dev/tty and (eventually) +// /dev/console. +// +// TODO(b/159623826): Support /dev/console. +package ttydev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +const ( + // See drivers/tty/tty_io.c:tty_init(). + ttyDevMinor = 0 + consoleDevMinor = 1 +) + +// ttyDevice implements vfs.Device for /dev/tty. +type ttyDevice struct{} + +// Open implements vfs.Device.Open. +func (ttyDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &ttyFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// ttyFD implements vfs.FileDescriptionImpl for /dev/tty. +type ttyFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *ttyFD) Release() {} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *ttyFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *ttyFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *ttyFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return src.NumBytes(), nil +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *ttyFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return src.NumBytes(), nil +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, ttyDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "tty", + }) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. +func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "tty", vfs.CharDevice, linux.TTYAUX_MAJOR, ttyDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/devices/tundev/BUILD b/pkg/sentry/devices/tundev/BUILD new file mode 100644 index 000000000..71c59287c --- /dev/null +++ b/pkg/sentry/devices/tundev/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "tundev", + srcs = ["tundev.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/inet", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netstack", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//pkg/tcpip/link/tun", + "//pkg/usermem", + "//pkg/waiter", + ], +) diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go new file mode 100644 index 000000000..dfbd069af --- /dev/null +++ b/pkg/sentry/devices/tundev/tundev.go @@ -0,0 +1,178 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tundev implements the /dev/net/tun device. +package tundev + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip/link/tun" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + netTunDevMajor = 10 + netTunDevMinor = 200 +) + +// tunDevice implements vfs.Device for /dev/net/tun. +type tunDevice struct{} + +// Open implements vfs.Device.Open. +func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &tunFD{} + if err := fd.vfsfd.Init(fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +type tunFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + device tun.Device +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + request := args[1].Uint() + data := args[2].Pointer() + + switch request { + case linux.TUNSETIFF: + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + if !t.HasCapability(linux.CAP_NET_ADMIN) { + return 0, syserror.EPERM + } + stack, ok := t.NetworkContext().(*netstack.Stack) + if !ok { + return 0, syserror.EINVAL + } + + var req linux.IFReq + if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return 0, err + } + flags := usermem.ByteOrder.Uint16(req.Data[:]) + return 0, fd.device.SetIff(stack.Stack, req.Name(), flags) + + case linux.TUNGETIFF: + var req linux.IFReq + + copy(req.IFName[:], fd.device.Name()) + + // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when + // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c. + flags := fd.device.Flags() | linux.IFF_NOFILTER + usermem.ByteOrder.PutUint16(req.Data[:], flags) + + _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err + + default: + return 0, syserror.ENOTTY + } +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *tunFD) Release() { + fd.device.Release() +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *tunFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return fd.Read(ctx, dst, opts) +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *tunFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + data, err := fd.device.Read() + if err != nil { + return 0, err + } + n, err := dst.CopyOut(ctx, data) + if n > 0 && n < len(data) { + // Not an error for partial copying. Packet truncated. + err = nil + } + return int64(n), err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *tunFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return fd.Write(ctx, src, opts) +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *tunFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + data := make([]byte, src.NumBytes()) + if _, err := src.CopyIn(ctx, data); err != nil { + return 0, err + } + return fd.device.Write(data) +} + +// Readiness implements watier.Waitable.Readiness. +func (fd *tunFD) Readiness(mask waiter.EventMask) waiter.EventMask { + return fd.device.Readiness(mask) +} + +// EventRegister implements watier.Waitable.EventRegister. +func (fd *tunFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.device.EventRegister(e, mask) +} + +// EventUnregister implements watier.Waitable.EventUnregister. +func (fd *tunFD) EventUnregister(e *waiter.Entry) { + fd.device.EventUnregister(e) +} + +// isNetTunSupported returns whether /dev/net/tun device is supported for s. +func isNetTunSupported(s inet.Stack) bool { + _, ok := s.(*netstack.Stack) + return ok +} + +// Register registers all devices implemented by this package in vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + return vfsObj.RegisterDevice(vfs.CharDevice, netTunDevMajor, netTunDevMinor, tunDevice{}, &vfs.RegisterDeviceOptions{}) +} + +// CreateDevtmpfsFiles creates device special files in dev representing all +// devices implemented by this package. +func CreateDevtmpfsFiles(ctx context.Context, dev *devtmpfs.Accessor) error { + return dev.CreateDeviceFile(ctx, "net/tun", vfs.CharDevice, netTunDevMajor, netTunDevMinor, 0666 /* mode */) +} diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index beba0f771..f5537411e 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -160,6 +160,7 @@ type FileOperations interface { // refer. // // Preconditions: The AddressSpace (if any) that io refers to is activated. + // Must only be called from a task goroutine. Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) } diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD index 789369220..5fb419bcd 100644 --- a/pkg/sentry/fs/fsutil/BUILD +++ b/pkg/sentry/fs/fsutil/BUILD @@ -8,7 +8,6 @@ go_template_instance( out = "dirty_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "Dirty", @@ -25,14 +24,14 @@ go_template_instance( name = "frame_ref_set_impl", out = "frame_ref_set_impl.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "fsutil", prefix = "FrameRef", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "uint64", "Functions": "FrameRefSetFunctions", }, @@ -43,7 +42,6 @@ go_template_instance( out = "file_range_set_impl.go", imports = { "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", }, package = "fsutil", prefix = "FileRange", @@ -86,7 +84,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/state", diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go index c6cd45087..2c9446c1d 100644 --- a/pkg/sentry/fs/fsutil/dirty_set.go +++ b/pkg/sentry/fs/fsutil/dirty_set.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) { // repeatedly until all bytes have been written. max is the true size of the // cached object; offsets beyond max will not be passed to writeAt, even if // they are marked dirty. -func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { var changedDirty bool defer func() { if changedDirty { @@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet // successful partial write, SyncDirtyAll will call it repeatedly until all // bytes have been written. max is the true size of the cached object; offsets // beyond max will not be passed to writeAt, even if they are marked dirty. -func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { dseg := dirty.FirstSegment() for dseg.Ok() { if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil { @@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max } // Preconditions: mr must be page-aligned. -func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { +func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error { for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() { wbr := cseg.Range().Intersect(mr) if max < wbr.Start { diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index 5643cdac9..bbafebf03 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -23,13 +23,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/usermem" ) // FileRangeSet maps offsets into a memmap.Mappable to offsets into a -// platform.File. It is used to implement Mappables that store data in +// memmap.File. It is used to implement Mappables that store data in // sparsely-allocated memory. // // type FileRangeSet <generated by go_generics> @@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli } // FileRange returns the FileRange mapped by seg. -func (seg FileRangeIterator) FileRange() platform.FileRange { +func (seg FileRangeIterator) FileRange() memmap.FileRange { return seg.FileRangeOf(seg.Range()) } // FileRangeOf returns the FileRange mapped by mr. // // Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. -func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange { +func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) - return platform.FileRange{frstart, frstart + mr.Length()} + return memmap.FileRange{frstart, frstart + mr.Length()} } // Fill attempts to ensure that all memmap.Mappable offsets in required are -// mapped to a platform.File offset, by allocating from mf with the given +// mapped to a memmap.File offset, by allocating from mf with the given // memory usage kind and invoking readAt to store data into memory. (If readAt // returns a successful partial read, Fill will call it repeatedly until all // bytes have been read.) EOF is handled consistently with the requirements of @@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map } // Drop removes segments for memmap.Mappable offsets in mr, freeing the -// corresponding platform.FileRanges. +// corresponding memmap.FileRanges. // // Preconditions: mr must be page-aligned. func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { @@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) { } // DropAll removes all segments in mr, freeing the corresponding -// platform.FileRanges. +// memmap.FileRanges. func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) { for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { mf.DecRef(seg.FileRange()) diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go index dd6f5aba6..a808894df 100644 --- a/pkg/sentry/fs/fsutil/frame_ref_set.go +++ b/pkg/sentry/fs/fsutil/frame_ref_set.go @@ -17,7 +17,7 @@ package fsutil import ( "math" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" ) @@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) { } // Merge implements segment.Functions.Merge. -func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) { +func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) { if val1 != val2 { return 0, false } @@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform. } // Split implements segment.Functions.Split. -func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) { +func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) { return val, val } // IncRefAndAccount adds a reference on the range fr. All newly inserted segments // are accounted as host page cache memory mappings. -func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) { seg, gap := refs.Find(fr.Start) for { switch { @@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) { // DecRefAndAccount removes a reference on the range fr and untracks segments // that are removed from memory accounting. -func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) { +func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) { seg := refs.FindSegment(fr.Start) for seg.Ok() && seg.Start() < fr.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index e82afd112..ef0113b52 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { // offsets in fr or until the next call to UnmapAll. // // Preconditions: The caller must hold a reference on all offsets in fr. -func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) { +func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) { chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift) f.mapsMu.Lock() defer f.mapsMu.Unlock() @@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) } // Preconditions: f.mapsMu must be locked. -func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error { +func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error { prot := syscall.PROT_READ if write { prot |= syscall.PROT_WRITE diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go index 78fec553e..c15d8a946 100644 --- a/pkg/sentry/fs/fsutil/host_mappable.go +++ b/pkg/sentry/fs/fsutil/host_mappable.go @@ -21,18 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// HostMappable implements memmap.Mappable and platform.File over a +// HostMappable implements memmap.Mappable and memmap.File over a // CachedFileObject. // // Lock order (compare the lock order model in mm/mm.go): // truncateMu ("fs locks") // mu ("memmap.Mappable locks not taken by Translate") -// ("platform.File locks") +// ("memmap.File locks") // backingFile ("CachedFileObject locks") // // +stateify savable @@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error { return nil } -// MapInternal implements platform.File.MapInternal. -func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (h *HostMappable) FD() int { return h.backingFile.FD() } -// IncRef implements platform.File.IncRef. -func (h *HostMappable) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (h *HostMappable) IncRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.IncRefOn(mr) } -// DecRef implements platform.File.DecRef. -func (h *HostMappable) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (h *HostMappable) DecRef(fr memmap.FileRange) { mr := memmap.MappableRange{Start: fr.Start, End: fr.End} h.hostFileMapper.DecRefOn(mr) } diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index 800c8b4e1..fe8b0b6ac 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -26,7 +26,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. c.mapsMu.Lock() @@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable } } -// IncRef implements platform.File.IncRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// IncRef implements memmap.File.IncRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { +func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg, gap := c.refs.Find(fr.Start) @@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) { } } -// DecRef implements platform.File.DecRef. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// DecRef implements memmap.File.DecRef. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. -func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { +func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) { // Hot path. Avoid defers. c.dataMu.Lock() seg := c.refs.FindSegment(fr.Start) @@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) { c.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. This is used when we +// MapInternal implements memmap.File.MapInternal. This is used when we // directly map an underlying host fd and CachingInodeOperations is used as the -// platform.File during translation. -func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// memmap.File during translation. +func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write) } -// FD implements platform.File.FD. This is used when we directly map an -// underlying host fd and CachingInodeOperations is used as the platform.File +// FD implements memmap.File.FD. This is used when we directly map an +// underlying host fd and CachingInodeOperations is used as the memmap.File // during translation. func (c *CachingInodeOperations) FD() int { return c.backingFile.FD() diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index aabce6cc9..d41d23a43 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/fdnotifier", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", diff --git a/pkg/sentry/fs/host/socket_iovec.go b/pkg/sentry/fs/host/socket_iovec.go index 5c18dbd5e..905afb50d 100644 --- a/pkg/sentry/fs/host/socket_iovec.go +++ b/pkg/sentry/fs/host/socket_iovec.go @@ -17,15 +17,12 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) // LINT.IfChange -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -76,7 +73,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 69879498a..1081fff52 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -67,8 +67,8 @@ func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vf } // Stat implements kernfs.Inode.Stat. -func (mi *masterInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := mi.InodeAttrs.Stat(vfsfs, opts) +func (mi *masterInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := mi.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -186,7 +186,7 @@ func (mfd *masterFileDescription) SetStat(ctx context.Context, opts vfs.SetStatO // Stat implements vfs.FileDescriptionImpl.Stat. func (mfd *masterFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := mfd.vfsfd.VirtualDentry().Mount().Filesystem() - return mfd.inode.Stat(fs, opts) + return mfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/slave.go index cf1a0f0ac..a91cae3ef 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/slave.go @@ -73,8 +73,8 @@ func (si *slaveInode) Valid(context.Context) bool { } // Stat implements kernfs.Inode.Stat. -func (si *slaveInode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - statx, err := si.InodeAttrs.Stat(vfsfs, opts) +func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } @@ -132,7 +132,7 @@ func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequen return sfd.inode.t.ld.outputQueueWrite(ctx, src) } -// Ioctl implements vfs.FileDescripionImpl.Ioctl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ @@ -183,7 +183,7 @@ func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOp // Stat implements vfs.FileDescriptionImpl.Stat. func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() - return sfd.inode.Stat(fs, opts) + return sfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index 142ee53b0..d0e06cdc0 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -136,6 +136,8 @@ func (a *Accessor) pathOperationAt(pathname string) *vfs.PathOperation { // CreateDeviceFile creates a device special file at the given pathname in the // devtmpfs instance accessed by the Accessor. func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind vfs.DeviceKind, major, minor uint32, perms uint16) error { + actx := a.wrapContext(ctx) + mode := (linux.FileMode)(perms) switch kind { case vfs.BlockDevice: @@ -145,12 +147,24 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v default: panic(fmt.Sprintf("invalid vfs.DeviceKind: %v", kind)) } + + // Create any parent directories. See + // devtmpfs.c:handle_create()=>path_create(). + for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() { + pop := a.pathOperationAt(it.String()) + if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{ + Mode: 0755, + }); err != nil { + return fmt.Errorf("failed to create directory %q: %v", it.String(), err) + } + } + // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't // create, which it recognizes by storing a pointer to the kdevtmpfs struct // thread in struct inode::i_private. Accessor doesn't yet support deletion // of files at all, and probably won't as long as we don't need to support // kernel modules, so this is moot for now. - return a.vfsObj.MknodAt(a.wrapContext(ctx), a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{ + return a.vfsObj.MknodAt(actx, a.creds, a.pathOperationAt(pathname), &vfs.MknodOptions{ Mode: mode, DevMajor: major, DevMinor: minor, diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD index ef24f8159..abc610ef3 100644 --- a/pkg/sentry/fsimpl/ext/BUILD +++ b/pkg/sentry/fsimpl/ext/BUILD @@ -96,7 +96,7 @@ go_test( "//pkg/syserror", "//pkg/test/testutil", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/ext/dentry.go b/pkg/sentry/fsimpl/ext/dentry.go index 6bd1a9fc6..55902322a 100644 --- a/pkg/sentry/fsimpl/ext/dentry.go +++ b/pkg/sentry/fsimpl/ext/dentry.go @@ -63,12 +63,17 @@ func (d *dentry) DecRef() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} +// TODO(b/134676337): Implement inotify. +func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. // -// TODO(gvisor.dev/issue/1479): Implement inotify. +// TODO(b/134676337): Implement inotify. func (d *dentry) Watches() *vfs.Watches { return nil } + +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +// +// TODO(b/134676337): Implement inotify. +func (d *dentry) OnZeroWatches() {} diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD new file mode 100644 index 000000000..999111deb --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -0,0 +1,63 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +licenses(["notice"]) + +go_template_instance( + name = "request_list", + out = "request_list.go", + package = "fuse", + prefix = "request", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Request", + "Linker": "*Request", + }, +) + +go_library( + name = "fuse", + srcs = [ + "connection.go", + "dev.go", + "fusefs.go", + "init.go", + "register.go", + "request_list.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/log", + "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/kernfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", + "//pkg/sync", + "//pkg/syserror", + "//pkg/usermem", + "//pkg/waiter", + "//tools/go_marshal/marshal", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +go_test( + name = "fuse_test", + size = "small", + srcs = ["dev_test.go"], + library = ":fuse", + deps = [ + "//pkg/abi/linux", + "//pkg/sentry/fsimpl/testutil", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", + "//pkg/syserror", + "//pkg/usermem", + "//pkg/waiter", + "//tools/go_marshal/marshal", + ], +) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go new file mode 100644 index 000000000..6df2728ab --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -0,0 +1,437 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// maxActiveRequestsDefault is the default setting controlling the upper bound +// on the number of active requests at any given time. +const maxActiveRequestsDefault = 10000 + +// Ordinary requests have even IDs, while interrupts IDs are odd. +// Used to increment the unique ID for each FUSE request. +var reqIDStep uint64 = 2 + +const ( + // fuseDefaultMaxBackground is the default value for MaxBackground. + fuseDefaultMaxBackground = 12 + + // fuseDefaultCongestionThreshold is the default value for CongestionThreshold, + // and is 75% of the default maximum of MaxGround. + fuseDefaultCongestionThreshold = (fuseDefaultMaxBackground * 3 / 4) + + // fuseDefaultMaxPagesPerReq is the default value for MaxPagesPerReq. + fuseDefaultMaxPagesPerReq = 32 +) + +// Request represents a FUSE operation request that hasn't been sent to the +// server yet. +// +// +stateify savable +type Request struct { + requestEntry + + id linux.FUSEOpID + hdr *linux.FUSEHeaderIn + data []byte +} + +// Response represents an actual response from the server, including the +// response payload. +// +// +stateify savable +type Response struct { + opcode linux.FUSEOpcode + hdr linux.FUSEHeaderOut + data []byte +} + +// connection is the struct by which the sentry communicates with the FUSE server daemon. +type connection struct { + fd *DeviceFD + + // The following FUSE_INIT flags are currently unsupported by this implementation: + // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC) + // - FUSE_EXPORT_SUPPORT + // - FUSE_HANDLE_KILLPRIV + // - FUSE_POSIX_LOCKS: requires POSIX locks + // - FUSE_FLOCK_LOCKS: requires POSIX locks + // - FUSE_AUTO_INVAL_DATA: requires page caching eviction + // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction + // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation + // - FUSE_ASYNC_DIO + // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler + + // initialized after receiving FUSE_INIT reply. + // Until it's set, suspend sending FUSE requests. + // Use SetInitialized() and IsInitialized() for atomic access. + initialized int32 + + // initializedChan is used to block requests before initialization. + initializedChan chan struct{} + + // blocked when there are too many outstading backgrounds requests (NumBackground == MaxBackground). + // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block. + blocked bool + + // connected (connection established) when a new FUSE file system is created. + // Set to false when: + // umount, + // connection abort, + // device release. + connected bool + + // aborted via sysfs. + // TODO(gvisor.dev/issue/3185): abort all queued requests. + aborted bool + + // connInitError if FUSE_INIT encountered error (major version mismatch). + // Only set in INIT. + connInitError bool + + // connInitSuccess if FUSE_INIT is successful. + // Only set in INIT. + // Used for destory. + connInitSuccess bool + + // TODO(gvisor.dev/issue/3185): All the queue logic are working in progress. + + // NumberBackground is the number of requests in the background. + numBackground uint16 + + // congestionThreshold for NumBackground. + // Negotiated in FUSE_INIT. + congestionThreshold uint16 + + // maxBackground is the maximum number of NumBackground. + // Block connection when it is reached. + // Negotiated in FUSE_INIT. + maxBackground uint16 + + // numActiveBackground is the number of requests in background and has being marked as active. + numActiveBackground uint16 + + // numWating is the number of requests waiting for completion. + numWaiting uint32 + + // TODO(gvisor.dev/issue/3185): BgQueue + // some queue for background queued requests. + + // bgLock protects: + // MaxBackground, CongestionThreshold, NumBackground, + // NumActiveBackground, BgQueue, Blocked. + bgLock sync.Mutex + + // maxRead is the maximum size of a read buffer in in bytes. + maxRead uint32 + + // maxWrite is the maximum size of a write buffer in bytes. + // Negotiated in FUSE_INIT. + maxWrite uint32 + + // maxPages is the maximum number of pages for a single request to use. + // Negotiated in FUSE_INIT. + maxPages uint16 + + // minor version of the FUSE protocol. + // Negotiated and only set in INIT. + minor uint32 + + // asyncRead if read pages asynchronously. + // Negotiated and only set in INIT. + asyncRead bool + + // abortErr is true if kernel need to return an unique read error after abort. + // Negotiated and only set in INIT. + abortErr bool + + // writebackCache is true for write-back cache policy, + // false for write-through policy. + // Negotiated and only set in INIT. + writebackCache bool + + // cacheSymlinks if filesystem needs to cache READLINK responses in page cache. + // Negotiated and only set in INIT. + cacheSymlinks bool + + // bigWrites if doing multi-page cached writes. + // Negotiated and only set in INIT. + bigWrites bool + + // dontMask if filestestem does not apply umask to creation modes. + // Negotiated in INIT. + dontMask bool +} + +// newFUSEConnection creates a FUSE connection to fd. +func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) { + // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to + // mount a FUSE filesystem. + fuseFD := fd.Impl().(*DeviceFD) + fuseFD.mounted = true + + // Create the writeBuf for the header to be stored in. + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + fuseFD.writeBuf = make([]byte, hdrLen) + fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse) + fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests) + fuseFD.writeCursor = 0 + + return &connection{ + fd: fuseFD, + maxBackground: fuseDefaultMaxBackground, + congestionThreshold: fuseDefaultCongestionThreshold, + maxPages: fuseDefaultMaxPagesPerReq, + initializedChan: make(chan struct{}), + connected: true, + }, nil +} + +// SetInitialized atomically sets the connection as initialized. +func (conn *connection) SetInitialized() { + // Unblock the requests sent before INIT. + close(conn.initializedChan) + + // Close the channel first to avoid the non-atomic situation + // where conn.initialized is true but there are + // tasks being blocked on the channel. + // And it prevents the newer tasks from gaining + // unnecessary higher chance to be issued before the blocked one. + + atomic.StoreInt32(&(conn.initialized), int32(1)) +} + +// IsInitialized atomically check if the connection is initialized. +// pairs with SetInitialized(). +func (conn *connection) Initialized() bool { + return atomic.LoadInt32(&(conn.initialized)) != 0 +} + +// NewRequest creates a new request that can be sent to the FUSE server. +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) + + hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() + hdr := linux.FUSEHeaderIn{ + Len: uint32(hdrLen + payload.SizeBytes()), + Opcode: opcode, + Unique: conn.fd.nextOpID, + NodeID: ino, + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + PID: pid, + } + + buf := make([]byte, hdr.Len) + hdr.MarshalUnsafe(buf[:hdrLen]) + payload.MarshalUnsafe(buf[hdrLen:]) + + return &Request{ + id: hdr.Unique, + hdr: &hdr, + data: buf, + }, nil +} + +// Call makes a request to the server and blocks the invoking task until a +// server responds with a response. Task should never be nil. +// Requests will not be sent before the connection is initialized. +// For async tasks, use CallAsync(). +func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { + // Block requests sent before connection is initalized. + if !conn.Initialized() { + if err := t.Block(conn.initializedChan); err != nil { + return nil, err + } + } + + return conn.call(t, r) +} + +// CallAsync makes an async (aka background) request. +// Those requests either do not expect a response (e.g. release) or +// the response should be handled by others (e.g. init). +// Return immediately unless the connection is blocked (before initialization). +// Async call example: init, release, forget, aio, interrupt. +// When the Request is FUSE_INIT, it will not be blocked before initialization. +func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { + // Block requests sent before connection is initalized. + if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT { + if err := t.Block(conn.initializedChan); err != nil { + return err + } + } + + // This should be the only place that invokes call() with a nil task. + _, err := conn.call(nil, r) + return err +} + +// call makes a call without blocking checks. +func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) { + if !conn.connected { + return nil, syserror.ENOTCONN + } + + if conn.connInitError { + return nil, syserror.ECONNREFUSED + } + + fut, err := conn.callFuture(t, r) + if err != nil { + return nil, err + } + + return fut.resolve(t) +} + +// Error returns the error of the FUSE call. +func (r *Response) Error() error { + errno := r.hdr.Error + if errno >= 0 { + return nil + } + + sysErrNo := syscall.Errno(-errno) + return error(sysErrNo) +} + +// UnmarshalPayload unmarshals the response data into m. +func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { + hdrLen := r.hdr.SizeBytes() + haveDataLen := r.hdr.Len - uint32(hdrLen) + wantDataLen := uint32(m.SizeBytes()) + + if haveDataLen < wantDataLen { + return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) + } + + m.UnmarshalUnsafe(r.data[hdrLen:]) + return nil +} + +// callFuture makes a request to the server and returns a future response. +// Call resolve() when the response needs to be fulfilled. +func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + + // Is the queue full? + // + // We must busy wait here until the request can be queued. We don't + // block on the fd.fullQueueCh with a lock - so after being signalled, + // before we acquire the lock, it is possible that a barging task enters + // and queues a request. As a result, upon acquiring the lock we must + // again check if the room is available. + // + // This can potentially starve a request forever but this can only happen + // if there are always too many ongoing requests all the time. The + // supported maxActiveRequests setting should be really high to avoid this. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + if t == nil { + // Since there is no task that is waiting. We must error out. + return nil, errors.New("FUSE request queue full") + } + + log.Infof("Blocking request %v from being queued. Too many active requests: %v", + r.id, conn.fd.numActiveRequests) + conn.fd.mu.Unlock() + err := t.Block(conn.fd.fullQueueCh) + conn.fd.mu.Lock() + if err != nil { + return nil, err + } + } + + return conn.callFutureLocked(t, r) +} + +// callFutureLocked makes a request to the server and returns a future response. +func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) { + conn.fd.queue.PushBack(r) + conn.fd.numActiveRequests += 1 + fut := newFutureResponse(r.hdr.Opcode) + conn.fd.completions[r.id] = fut + + // Signal the readers that there is something to read. + conn.fd.waitQueue.Notify(waiter.EventIn) + + return fut, nil +} + +// futureResponse represents an in-flight request, that may or may not have +// completed yet. Convert it to a resolved Response by calling Resolve, but note +// that this may block. +// +// +stateify savable +type futureResponse struct { + opcode linux.FUSEOpcode + ch chan struct{} + hdr *linux.FUSEHeaderOut + data []byte +} + +// newFutureResponse creates a future response to a FUSE request. +func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse { + return &futureResponse{ + opcode: opcode, + ch: make(chan struct{}), + } +} + +// resolve blocks the task until the server responds to its corresponding request, +// then returns a resolved response. +func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { + // If there is no Task associated with this request - then we don't try to resolve + // the response. Instead, the task writing the response (proxy to the server) will + // process the response on our behalf. + if t == nil { + log.Infof("fuse.Response.resolve: Not waiting on a response from server.") + return nil, nil + } + + if err := t.Block(f.ch); err != nil { + return nil, err + } + + return f.getResponse(), nil +} + +// getResponse creates a Response from the data the futureResponse has. +func (f *futureResponse) getResponse() *Response { + return &Response{ + opcode: f.opcode, + hdr: *f.hdr, + data: f.data, + } +} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go new file mode 100644 index 000000000..2225076bc --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -0,0 +1,397 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +const fuseDevMinor = 229 + +// fuseDevice implements vfs.Device for /dev/fuse. +type fuseDevice struct{} + +// Open implements vfs.Device.Open. +func (fuseDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + if !kernel.FUSEEnabled { + return nil, syserror.ENOENT + } + + var fd DeviceFD + if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{ + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// DeviceFD implements vfs.FileDescriptionImpl for /dev/fuse. +type DeviceFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD. + mounted bool + + // nextOpID is used to create new requests. + nextOpID linux.FUSEOpID + + // queue is the list of requests that need to be processed by the FUSE server. + queue requestList + + // numActiveRequests is the number of requests made by the Sentry that has + // yet to be responded to. + numActiveRequests uint64 + + // completions is used to map a request to its response. A Writer will use this + // to notify the caller of a completed response. + completions map[linux.FUSEOpID]*futureResponse + + writeCursor uint32 + + // writeBuf is the memory buffer used to copy in the FUSE out header from + // userspace. + writeBuf []byte + + // writeCursorFR current FR being copied from server. + writeCursorFR *futureResponse + + // mu protects all the queues, maps, buffers and cursors and nextOpID. + mu sync.Mutex + + // waitQueue is used to notify interested parties when the device becomes + // readable or writable. + waitQueue waiter.Queue + + // fullQueueCh is a channel used to synchronize the readers with the writers. + // Writers (inbound requests to the filesystem) block if there are too many + // unprocessed in-flight requests. + fullQueueCh chan struct{} + + // fs is the FUSE filesystem that this FD is being used for. + fs *filesystem +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *DeviceFD) Release() { + fd.fs.conn.connected = false +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + + return 0, syserror.ENOSYS +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + + // We require that any Read done on this filesystem have a sane minimum + // read buffer. It must have the capacity for the fixed parts of any request + // header (Linux uses the request header and the FUSEWriteIn header for this + // calculation) + the negotiated MaxWrite room for the data. + minBuffSize := linux.FUSE_MIN_READ_BUFFER + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes()) + negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.maxWrite + if minBuffSize < negotiatedMinBuffSize { + minBuffSize = negotiatedMinBuffSize + } + + // If the read buffer is too small, error out. + if dst.NumBytes() < int64(minBuffSize) { + return 0, syserror.EINVAL + } + + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readLocked(ctx, dst, opts) +} + +// readLocked implements the reading of the fuse device while locked with DeviceFD.mu. +func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + if fd.queue.Empty() { + return 0, syserror.ErrWouldBlock + } + + var readCursor uint32 + var bytesRead int64 + for { + req := fd.queue.Front() + if dst.NumBytes() < int64(req.hdr.Len) { + // The request is too large. Cannot process it. All requests must be smaller than the + // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT + // handshake. + errno := -int32(syscall.EIO) + if req.hdr.Opcode == linux.FUSE_SETXATTR { + errno = -int32(syscall.E2BIG) + } + + // Return the error to the calling task. + if err := fd.sendError(ctx, errno, req); err != nil { + return 0, err + } + + // We're done with this request. + fd.queue.Remove(req) + + // Restart the read as this request was invalid. + log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.") + return fd.readLocked(ctx, dst, opts) + } + + n, err := dst.CopyOut(ctx, req.data[readCursor:]) + if err != nil { + return 0, err + } + readCursor += uint32(n) + bytesRead += int64(n) + + if readCursor >= req.hdr.Len { + // Fully done with this req, remove it from the queue. + fd.queue.Remove(req) + break + } + } + + return bytesRead, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + + return 0, syserror.ENOSYS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.writeLocked(ctx, src, opts) +} + +// writeLocked implements writing to the fuse device while locked with DeviceFD.mu. +func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + + var cn, n int64 + hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + + for src.NumBytes() > 0 { + if fd.writeCursorFR != nil { + // Already have common header, and we're now copying the payload. + wantBytes := fd.writeCursorFR.hdr.Len + + // Note that the FR data doesn't have the header. Copy it over if its necessary. + if fd.writeCursorFR.data == nil { + fd.writeCursorFR.data = make([]byte, wantBytes) + } + + bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == wantBytes { + // Done reading this full response. Clean up and unblock the + // initiator. + break + } + + // Check if we have more data in src. + continue + } + + // Assert that the header isn't read into the writeBuf yet. + if fd.writeCursor >= hdrLen { + return 0, syserror.EINVAL + } + + // We don't have the full common response header yet. + wantBytes := hdrLen - fd.writeCursor + bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes]) + if err != nil { + return 0, err + } + src = src.DropFirst(bytesCopied) + + cn = int64(bytesCopied) + n += cn + fd.writeCursor += uint32(cn) + if fd.writeCursor == hdrLen { + // Have full header in the writeBuf. Use it to fetch the actual futureResponse + // from the device's completions map. + var hdr linux.FUSEHeaderOut + hdr.UnmarshalBytes(fd.writeBuf) + + // We have the header now and so the writeBuf has served its purpose. + // We could reset it manually here but instead of doing that, at the + // end of the write, the writeCursor will be set to 0 thereby allowing + // the next request to overwrite whats in the buffer, + + fut, ok := fd.completions[hdr.Unique] + if !ok { + // Server sent us a response for a request we never sent? + return 0, syserror.EINVAL + } + + delete(fd.completions, hdr.Unique) + + // Copy over the header into the future response. The rest of the payload + // will be copied over to the FR's data in the next iteration. + fut.hdr = &hdr + fd.writeCursorFR = fut + + // Next iteration will now try read the complete request, if src has + // any data remaining. Otherwise we're done. + } + } + + if fd.writeCursorFR != nil { + if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil { + return 0, err + } + + // Ready the device for the next request. + fd.writeCursorFR = nil + fd.writeCursor = 0 + } + + return n, nil +} + +// Readiness implements vfs.FileDescriptionImpl.Readiness. +func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { + var ready waiter.EventMask + ready |= waiter.EventOut // FD is always writable + if !fd.queue.Empty() { + // Have reqs available, FD is readable. + ready |= waiter.EventIn + } + + return ready & mask +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.waitQueue.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *DeviceFD) EventUnregister(e *waiter.Entry) { + fd.waitQueue.EventUnregister(e) +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. + if !fd.mounted { + return 0, syserror.EPERM + } + + return 0, syserror.ENOSYS +} + +// sendResponse sends a response to the waiting task (if any). +func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error { + // See if the running task need to perform some action before returning. + // Since we just finished writing the future, we can be sure that + // getResponse generates a populated response. + if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil { + return err + } + + // Signal that the queue is no longer full. + select { + case fd.fullQueueCh <- struct{}{}: + default: + } + fd.numActiveRequests -= 1 + + // Signal the task waiting on a response. + close(fut.ch) + return nil +} + +// sendError sends an error response to the waiting task (if any). +func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error { + // Return the error to the calling task. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + respHdr := linux.FUSEHeaderOut{ + Len: outHdrLen, + Error: errno, + Unique: req.hdr.Unique, + } + + fut, ok := fd.completions[respHdr.Unique] + if !ok { + // Server sent us a response for a request we never sent? + return syserror.EINVAL + } + delete(fd.completions, respHdr.Unique) + + fut.hdr = &respHdr + if err := fd.sendResponse(ctx, fut); err != nil { + return err + } + + return nil +} + +// noReceiverAction has the calling kernel.Task do some action if its known that no +// receiver is going to be waiting on the future channel. This is to be used by: +// FUSE_INIT. +func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error { + if r.opcode == linux.FUSE_INIT { + creds := auth.CredentialsFromContext(ctx) + rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace() + return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs)) + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go new file mode 100644 index 000000000..84c222ad6 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -0,0 +1,428 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "io" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// echoTestOpcode is the Opcode used during testing. The server used in tests +// will simply echo the payload back with the appropriate headers. +const echoTestOpcode linux.FUSEOpcode = 1000 + +type testPayload struct { + data uint32 +} + +// TestFUSECommunication tests that the communication layer between the Sentry and the +// FUSE server daemon works as expected. +func TestFUSECommunication(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + + // Create test cases with different number of concurrent clients and servers. + testCases := []struct { + Name string + NumClients int + NumServers int + MaxActiveRequests uint64 + }{ + { + Name: "SingleClientSingleServer", + NumClients: 1, + NumServers: 1, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "SingleClientMultipleServers", + NumClients: 1, + NumServers: 10, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "MultipleClientsSingleServer", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "MultipleClientsMultipleServers", + NumClients: 10, + NumServers: 10, + MaxActiveRequests: maxActiveRequestsDefault, + }, + { + Name: "RequestCapacityFull", + NumClients: 10, + NumServers: 1, + MaxActiveRequests: 1, + }, + { + Name: "RequestCapacityContinuouslyFull", + NumClients: 100, + NumServers: 2, + MaxActiveRequests: 2, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Name, func(t *testing.T) { + conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + clientsDone := make([]chan struct{}, testCase.NumClients) + serversDone := make([]chan struct{}, testCase.NumServers) + serversKill := make([]chan struct{}, testCase.NumServers) + + // FUSE clients. + for i := 0; i < testCase.NumClients; i++ { + clientsDone[i] = make(chan struct{}) + go func(i int) { + fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i]) + }(i) + } + + // FUSE servers. + for j := 0; j < testCase.NumServers; j++ { + serversDone[j] = make(chan struct{}) + serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block. + go func(j int) { + fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j]) + }(j) + } + + // Tear down. + // + // Make sure all the clients are done. + for i := 0; i < testCase.NumClients; i++ { + <-clientsDone[i] + } + + // Kill any server that is potentially waiting. + for j := 0; j < testCase.NumServers; j++ { + serversKill[j] <- struct{}{} + } + + // Make sure all the servers are done. + for j := 0; j < testCase.NumServers; j++ { + <-serversDone[j] + } + }) + } +} + +// CallTest makes a request to the server and blocks the invoking +// goroutine until a server responds with a response. Doesn't block +// a kernel.Task. Analogous to Connection.Call but used for testing. +func CallTest(conn *connection, t *kernel.Task, r *Request, i uint32) (*Response, error) { + conn.fd.mu.Lock() + + // Wait until we're certain that a new request can be processed. + for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { + conn.fd.mu.Unlock() + select { + case <-conn.fd.fullQueueCh: + } + conn.fd.mu.Lock() + } + + fut, err := conn.callFutureLocked(t, r) // No task given. + conn.fd.mu.Unlock() + + if err != nil { + return nil, err + } + + // Resolve the response. + // + // Block without a task. + select { + case <-fut.ch: + } + + // A response is ready. Resolve and return it. + return fut.getResponse(), nil +} + +// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE +// device. However, it does so by - not blocking the task that is calling - and +// instead just waits on a channel. The behaviour is essentially the same as +// DeviceFD.Read except it guarantees that the task is not blocked. +func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) { + var err error + var n, total int64 + + dev := fd.Impl().(*DeviceFD) + + // Register for notifications. + w, ch := waiter.NewChannelEntry(nil) + dev.EventRegister(&w, waiter.EventIn) + for { + // Issue the request and break out if it completes with anything other than + // "would block". + n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{}) + total += n + if err != syserror.ErrWouldBlock { + break + } + + // Wait for a notification that we should retry. + // Emulate the blocking for when no requests are available + select { + case <-ch: + case <-killServer: + // Server killed by the main program. + return 0, true, nil + } + } + + dev.EventUnregister(&w) + return total, false, err +} + +// fuseClientRun emulates all the actions of a normal FUSE request. It creates +// a header, a payload, calls the server, waits for the response, and processes +// the response. +func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) { + defer func() { clientDone <- struct{}{} }() + + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + testObj := &testPayload{ + data: rand.Uint32(), + } + + req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + + // Queue up a request. + // Analogous to Call except it doesn't block on the task. + resp, err := CallTest(conn, clientTask, req, pid) + if err != nil { + t.Fatalf("CallTaskNonBlock failed: %v", err) + } + + if err = resp.Error(); err != nil { + t.Fatalf("Server responded with an error: %v", err) + } + + var respTestPayload testPayload + if err := resp.UnmarshalPayload(&respTestPayload); err != nil { + t.Fatalf("Unmarshalling payload error: %v", err) + } + + if resp.hdr.Unique != req.hdr.Unique { + t.Fatalf("got response for another request. Expected response for req %v but got response for req %v", + req.hdr.Unique, resp.hdr.Unique) + } + + if respTestPayload.data != testObj.data { + t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + } + +} + +// fuseServerRun creates a task and emulates all the actions of a simple FUSE server +// that simply reads a request and echos the same struct back as a response using the +// appropriate headers. +func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) { + defer func() { serverDone <- struct{}{} }() + + // Create the tasks that the server will be using. + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + var readPayload testPayload + + serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) + if err != nil { + t.Fatal(err) + } + + // Read the request. + for { + inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes()) + payloadLen := uint32(readPayload.SizeBytes()) + + // The raed buffer must meet some certain size criteria. + buffSize := inHdrLen + payloadLen + if buffSize < linux.FUSE_MIN_READ_BUFFER { + buffSize = linux.FUSE_MIN_READ_BUFFER + } + inBuf := make([]byte, buffSize) + inIOseq := usermem.BytesIOSequence(inBuf) + + n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer) + if err != nil { + t.Fatalf("Read failed :%v", err) + } + + // Server should shut down. No new requests are going to be made. + if serverKilled { + break + } + + if n <= 0 { + t.Fatalf("Read read no bytes") + } + + var readFUSEHeaderIn linux.FUSEHeaderIn + readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) + readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + + if readFUSEHeaderIn.Opcode != echoTestOpcode { + t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) + } + + // Write the response. + outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) + outBuf := make([]byte, outHdrLen+payloadLen) + outHeader := linux.FUSEHeaderOut{ + Len: outHdrLen + payloadLen, + Error: 0, + Unique: readFUSEHeaderIn.Unique, + } + + // Echo the payload back. + outHeader.MarshalUnsafe(outBuf[:outHdrLen]) + readPayload.MarshalUnsafe(outBuf[outHdrLen:]) + outIOseq := usermem.BytesIOSequence(outBuf) + + n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed :%v", err) + } + } +} + +func setup(t *testing.T) *testutil.System { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("Error creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + AllowUserMount: true, + }) + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + + return testutil.NewSystem(ctx, t, k.VFS(), mntns) +} + +// newTestConnection creates a fuse connection that the sentry can communicate with +// and the FD for the server to communicate with. +func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { + vfsObj := &vfs.VirtualFilesystem{} + fuseDev := &DeviceFD{} + + if err := vfsObj.Init(); err != nil { + return nil, nil, err + } + + vd := vfsObj.NewAnonVirtualDentry("genCountFD") + defer vd.DecRef() + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, nil, err + } + + fsopts := filesystemOptions{ + maxActiveRequests: maxActiveRequests, + } + fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + if err != nil { + return nil, nil, err + } + + return fs.conn, &fuseDev.vfsfd, nil +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *testPayload) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *testPayload) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], t.data) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *testPayload) UnmarshalBytes(src []byte) { + *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} +} + +// Packed implements marshal.Marshallable.Packed. +func (t *testPayload) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *testPayload) MarshalUnsafe(dst []byte) { + t.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *testPayload) UnmarshalUnsafe(src []byte) { + t.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + panic("not implemented") +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *testPayload) WriteTo(w io.Writer) (int64, error) { + panic("not implemented") +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go new file mode 100644 index 000000000..200a93bbf --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -0,0 +1,228 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fuse implements fusefs. +package fuse + +import ( + "strconv" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// Name is the default filesystem name. +const Name = "fuse" + +// FilesystemType implements vfs.FilesystemType. +type FilesystemType struct{} + +type filesystemOptions struct { + // userID specifies the numeric uid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + userID uint32 + + // groupID specifies the numeric gid of the mount owner. + // This option should not be specified by the filesystem owner. + // It is set by libfuse (or, if libfuse is not used, must be set + // by the filesystem itself). For more information, see man page + // for fuse(8) + groupID uint32 + + // rootMode specifies the the file mode of the filesystem's root. + rootMode linux.FileMode + + // maxActiveRequests specifies the maximum number of active requests that can + // exist at any time. Any further requests will block when trying to + // Call the server. + maxActiveRequests uint64 +} + +// filesystem implements vfs.FilesystemImpl. +type filesystem struct { + kernfs.Filesystem + devMinor uint32 + + // conn is used for communication between the FUSE server + // daemon and the sentry fusefs. + conn *connection + + // opts is the options the fusefs is initialized with. + opts *filesystemOptions +} + +// Name implements vfs.FilesystemType.Name. +func (FilesystemType) Name() string { + return Name +} + +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. +func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + devMinor, err := vfsObj.GetAnonBlockDevMinor() + if err != nil { + return nil, nil, err + } + + var fsopts filesystemOptions + mopts := vfs.GenericParseMountOptions(opts.Data) + deviceDescriptorStr, ok := mopts["fd"] + if !ok { + log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name()) + return nil, nil, syserror.EINVAL + } + delete(mopts, "fd") + + deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) + if err != nil { + return nil, nil, err + } + + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("%s.GetFilesystem: couldn't get kernel task from context", fsType.Name()) + return nil, nil, syserror.EINVAL + } + fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + + // Parse and set all the other supported FUSE mount options. + // TODO(gVisor.dev/issue/3229): Expand the supported mount options. + if userIDStr, ok := mopts["user_id"]; ok { + delete(mopts, "user_id") + userID, err := strconv.ParseUint(userIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.userID = uint32(userID) + } + + if groupIDStr, ok := mopts["group_id"]; ok { + delete(mopts, "group_id") + groupID, err := strconv.ParseUint(groupIDStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr) + return nil, nil, syserror.EINVAL + } + fsopts.groupID = uint32(groupID) + } + + rootMode := linux.FileMode(0777) + modeStr, ok := mopts["rootmode"] + if ok { + delete(mopts, "rootmode") + mode, err := strconv.ParseUint(modeStr, 8, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) + return nil, nil, syserror.EINVAL + } + rootMode = linux.FileMode(mode) + } + fsopts.rootMode = rootMode + + // Set the maxInFlightRequests option. + fsopts.maxActiveRequests = maxActiveRequestsDefault + + // Check for unparsed options. + if len(mopts) != 0 { + log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts) + return nil, nil, syserror.EINVAL + } + + // Create a new FUSE filesystem. + fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + if err != nil { + log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) + return nil, nil, err + } + + fs.VFSFilesystem().Init(vfsObj, &fsType, fs) + + // Send a FUSE_INIT request to the FUSE daemon server before returning. + // This call is not blocking. + if err := fs.conn.InitSend(creds, uint32(kernelTask.ThreadID())); err != nil { + log.Warningf("%s.InitSend: failed with error: %v", fsType.Name(), err) + return nil, nil, err + } + + // root is the fusefs root directory. + root := fs.newInode(creds, fsopts.rootMode) + + return fs.VFSFilesystem(), root.VFSDentry(), nil +} + +// NewFUSEFilesystem creates a new FUSE filesystem. +func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { + fs := &filesystem{ + devMinor: devMinor, + opts: opts, + } + + conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests) + if err != nil { + log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) + return nil, syserror.EINVAL + } + + fs.conn = conn + fuseFD := device.Impl().(*DeviceFD) + fuseFD.fs = fs + + return fs, nil +} + +// Release implements vfs.FilesystemImpl.Release. +func (fs *filesystem) Release() { + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) + fs.Filesystem.Release() +} + +// inode implements kernfs.Inode. +type inode struct { + kernfs.InodeAttrs + kernfs.InodeNoDynamicLookup + kernfs.InodeNotSymlink + kernfs.InodeDirectoryNoNewChildren + kernfs.OrderedChildren + + locks vfs.FileLocks + + dentry kernfs.Dentry +} + +func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { + i := &inode{} + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.dentry.Init(i) + + return &i.dentry +} + +// Open implements kernfs.Inode.Open. +func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + if err != nil { + return nil, err + } + return fd.VFSFileDescription(), nil +} diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/init.go new file mode 100644 index 000000000..779c2bd3f --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/init.go @@ -0,0 +1,166 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// consts used by FUSE_INIT negotiation. +const ( + // fuseMaxMaxPages is the maximum value for MaxPages received in InitOut. + // Follow the same behavior as unix fuse implementation. + fuseMaxMaxPages = 256 + + // Maximum value for the time granularity for file time stamps, 1s. + // Follow the same behavior as unix fuse implementation. + fuseMaxTimeGranNs = 1000000000 + + // Minimum value for MaxWrite. + // Follow the same behavior as unix fuse implementation. + fuseMinMaxWrite = 4096 + + // Temporary default value for max readahead, 128kb. + fuseDefaultMaxReadahead = 131072 + + // The FUSE_INIT_IN flags sent to the daemon. + // TODO(gvisor.dev/issue/3199): complete the flags. + fuseDefaultInitFlags = linux.FUSE_MAX_PAGES +) + +// Adjustable maximums for Connection's cogestion control parameters. +// Used as the upperbound of the config values. +// Currently we do not support adjustment to them. +var ( + MaxUserBackgroundRequest uint16 = fuseDefaultMaxBackground + MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold +) + +// InitSend sends a FUSE_INIT request. +func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { + in := linux.FUSEInitIn{ + Major: linux.FUSE_KERNEL_VERSION, + Minor: linux.FUSE_KERNEL_MINOR_VERSION, + // TODO(gvisor.dev/issue/3196): find appropriate way to calculate this + MaxReadahead: fuseDefaultMaxReadahead, + Flags: fuseDefaultInitFlags, + } + + req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) + if err != nil { + return err + } + + // Since there is no task to block on and FUSE_INIT is the request + // to unblock other requests, use nil. + return conn.CallAsync(nil, req) +} + +// InitRecv receives a FUSE_INIT reply and process it. +func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error { + if err := res.Error(); err != nil { + return err + } + + var out linux.FUSEInitOut + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + return conn.initProcessReply(&out, hasSysAdminCap) +} + +// Process the FUSE_INIT reply from the FUSE server. +func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error { + // No support for old major fuse versions. + if out.Major != linux.FUSE_KERNEL_VERSION { + conn.connInitError = true + + // Set the connection as initialized and unblock the blocked requests + // (i.e. return error for them). + conn.SetInitialized() + + return nil + } + + // Start processing the reply. + conn.connInitSuccess = true + conn.minor = out.Minor + + // No support for limits before minor version 13. + if out.Minor >= 13 { + conn.bgLock.Lock() + + if out.MaxBackground > 0 { + conn.maxBackground = out.MaxBackground + + if !hasSysAdminCap && + conn.maxBackground > MaxUserBackgroundRequest { + conn.maxBackground = MaxUserBackgroundRequest + } + } + + if out.CongestionThreshold > 0 { + conn.congestionThreshold = out.CongestionThreshold + + if !hasSysAdminCap && + conn.congestionThreshold > MaxUserCongestionThreshold { + conn.congestionThreshold = MaxUserCongestionThreshold + } + } + + conn.bgLock.Unlock() + } + + // No support for the following flags before minor version 6. + if out.Minor >= 6 { + conn.asyncRead = out.Flags&linux.FUSE_ASYNC_READ != 0 + conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0 + conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0 + conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0 + conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0 + conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0 + + // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs). + + if out.Flags&linux.FUSE_MAX_PAGES != 0 { + maxPages := out.MaxPages + if maxPages < 1 { + maxPages = 1 + } + if maxPages > fuseMaxMaxPages { + maxPages = fuseMaxMaxPages + } + conn.maxPages = maxPages + } + } + + // No support for negotiating MaxWrite before minor version 5. + if out.Minor >= 5 { + conn.maxWrite = out.MaxWrite + } else { + conn.maxWrite = fuseMinMaxWrite + } + if conn.maxWrite < fuseMinMaxWrite { + conn.maxWrite = fuseMinMaxWrite + } + + // Set connection as initialized and unblock the requests + // issued before init. + conn.SetInitialized() + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go new file mode 100644 index 000000000..b5b581152 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/register.go @@ -0,0 +1,42 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" +) + +// Register registers the FUSE device with vfsObj. +func Register(vfsObj *vfs.VirtualFilesystem) error { + if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{ + GroupName: "misc", + }); err != nil { + return err + } + + return nil +} + +// CreateDevtmpfsFile creates a device special file in devtmpfs. +func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error { + if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil { + return err + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 480510b2a..8c7c8e1b3 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -85,6 +85,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { d2 := &dentry{ refs: 1, // held by d fs: d.fs, + ino: d.fs.nextSyntheticIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -138,6 +139,7 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fd.dirents = ds } + d.InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) if d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } @@ -183,13 +185,13 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { { Name: ".", Type: linux.DT_DIR, - Ino: d.ino, + Ino: uint64(d.ino), NextOff: 1, }, { Name: "..", Type: uint8(atomic.LoadUint32(&parent.mode) >> 12), - Ino: parent.ino, + Ino: uint64(parent.ino), NextOff: 2, }, } @@ -225,7 +227,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: p9d.QID.Path, + Ino: uint64(inoFromPath(p9d.QID.Path)), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or @@ -258,7 +260,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { dirents = append(dirents, vfs.Dirent{ Name: child.name, Type: uint8(atomic.LoadUint32(&child.mode) >> 12), - Ino: child.ino, + Ino: uint64(child.ino), NextOff: int64(len(dirents) + 1), }) } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 3c467e313..00e3c99cd 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -16,6 +16,7 @@ package gofer import ( "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -149,11 +150,9 @@ afterSymlink: return nil, err } if d != d.parent && !d.cachedMetadataAuthoritative() { - _, attrMask, attr, err := d.parent.file.getAttr(ctx, dentryAttrMask()) - if err != nil { + if err := d.parent.updateFromGetattr(ctx); err != nil { return nil, err } - d.parent.updateFromP9Attrs(attrMask, &attr) } rp.Advance() return d.parent, nil @@ -208,18 +207,28 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil // Preconditions: As for getChildLocked. !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { + if child != nil { + // Need to lock child.metadataMu because we might be updating child + // metadata. We need to hold the lock *before* getting metadata from the + // server and release it after updating local metadata. + child.metadataMu.Lock() + } qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name) if err != nil && err != syserror.ENOENT { + if child != nil { + child.metadataMu.Unlock() + } return nil, err } if child != nil { - if !file.isNil() && qid.Path == child.ino { - // The file at this path hasn't changed. Just update cached - // metadata. + if !file.isNil() && inoFromPath(qid.Path) == child.ino { + // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) - child.updateFromP9Attrs(attrMask, &attr) + child.updateFromP9AttrsLocked(attrMask, &attr) + child.metadataMu.Unlock() return child, nil } + child.metadataMu.Unlock() if file.isNil() && child.isSynthetic() { // We have a synthetic file, and no remote file has arisen to // replace it. @@ -371,17 +380,33 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } parent.touchCMtime() parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } if fs.opts.interop == InteropModeShared { - // The existence of a dentry at name would be inconclusive because the - // file it represents may have been deleted from the remote filesystem, - // so we would need to make an RPC to revalidate the dentry. Just - // attempt the file creation RPC instead. If a file does exist, the RPC - // will fail with EEXIST like we would have. If the RPC succeeds, and a - // stale dentry exists, the dentry will fail revalidation next time - // it's used. - return createInRemoteDir(parent, name) + if child := parent.children[name]; child != nil && child.isSynthetic() { + return syserror.EEXIST + } + // The existence of a non-synthetic dentry at name would be inconclusive + // because the file it represents may have been deleted from the remote + // filesystem, so we would need to make an RPC to revalidate the dentry. + // Just attempt the file creation RPC instead. If a file does exist, the + // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a + // stale dentry exists, the dentry will fail revalidation next time it's + // used. + if err := createInRemoteDir(parent, name); err != nil { + return err + } + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) + return nil } if child := parent.children[name]; child != nil { return syserror.EEXIST @@ -397,6 +422,11 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } parent.touchCMtime() parent.dirents = nil + ev := linux.IN_CREATE + if dir { + ev |= linux.IN_ISDIR + } + parent.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) return nil } @@ -443,21 +473,61 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b defer mntns.DecRef() parent.dirMu.Lock() defer parent.dirMu.Unlock() + child, ok := parent.children[name] if ok && child == nil { return syserror.ENOENT } - // We only need a dentry representing the file at name if it can be a mount - // point. If child is nil, then it can't be a mount point. If child is - // non-nil but stale, the actual file can't be a mount point either; we - // detect this case by just speculatively calling PrepareDeleteDentry and - // only revalidating the dentry if that fails (indicating that the existing - // dentry is a mount point). + + sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0 + if sticky { + if !ok { + // If the sticky bit is set, we need to retrieve the child to determine + // whether removing it is allowed. + child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + if err != nil { + return err + } + } else if child != nil && !child.cachedMetadataAuthoritative() { + // Make sure the dentry representing the file at name is up to date + // before examining its metadata. + child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) + if err != nil { + return err + } + } + if err := parent.mayDelete(rp.Credentials(), child); err != nil { + return err + } + } + + // If a child dentry exists, prepare to delete it. This should fail if it is + // a mount point. We detect mount points by speculatively calling + // PrepareDeleteDentry, which fails if child is a mount point. However, we + // may need to revalidate the file in this case to make sure that it has not + // been deleted or replaced on the remote fs, in which case the mount point + // will have disappeared. If calling PrepareDeleteDentry fails again on the + // up-to-date dentry, we can be sure that it is a mount point. + // + // Also note that if child is nil, then it can't be a mount point. if child != nil { + // Hold child.dirMu so we can check child.children and + // child.syntheticChildren. We don't access these fields until a bit later, + // but locking child.dirMu after calling vfs.PrepareDeleteDentry() would + // create an inconsistent lock ordering between dentry.dirMu and + // vfs.Dentry.mu (in the VFS lock order, it would make dentry.dirMu both "a + // FilesystemImpl lock" and "a lock acquired by a FilesystemImpl between + // PrepareDeleteDentry and CommitDeleteDentry). To avoid this, lock + // child.dirMu before calling PrepareDeleteDentry. child.dirMu.Lock() defer child.dirMu.Unlock() if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { - if parent.cachedMetadataAuthoritative() { + // We can skip revalidation in several cases: + // - We are not in InteropModeShared + // - The parent directory is synthetic, in which case the child must also + // be synthetic + // - We already updated the child during the sticky bit check above + if parent.cachedMetadataAuthoritative() || sticky { return err } child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds) @@ -518,7 +588,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b if child == nil { return syserror.ENOENT } - } else { + } else if child == nil || !child.isSynthetic() { err = parent.file.unlinkAt(ctx, name, flags) if err != nil { if child != nil { @@ -527,6 +597,18 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b return err } } + + // Generate inotify events for rmdir or unlink. + if dir { + parent.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) + } else { + var cw *vfs.Watches + if child != nil { + cw = &child.watches + } + vfs.InotifyRemoveChild(cw, &parent.watches, name) + } + if child != nil { vfsObj.CommitDeleteDentry(&child.vfsd) child.setDeleted() @@ -764,15 +846,17 @@ afterTrailingSymlink: parent.dirMu.Unlock() return fd, err } + parent.dirMu.Unlock() if err != nil { - parent.dirMu.Unlock() return nil, err } - // Open existing child or follow symlink. - parent.dirMu.Unlock() if mustCreate { return nil, syserror.EEXIST } + if !child.isDir() && rp.MustBeDir() { + return nil, syserror.ENOTDIR + } + // Open existing child or follow symlink. if child.isSymlink() && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx, rp.Mount()) if err != nil { @@ -793,11 +877,22 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } + + trunc := opts.Flags&linux.O_TRUNC != 0 && d.fileType() == linux.S_IFREG + if trunc { + // Lock metadataMu *while* we open a regular file with O_TRUNC because + // open(2) will change the file size on server. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + } + + var vfd *vfs.FileDescription + var err error mnt := rp.Mount() switch d.fileType() { case linux.S_IFREG: if !d.fs.opts.regularFilesUseSpecialFileFD { - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, opts.Flags&linux.O_TRUNC != 0); err != nil { + if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil { return nil, err } fd := ®ularFileFD{} @@ -807,7 +902,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }); err != nil { return nil, err } - return &fd.vfsfd, nil + vfd = &fd.vfsfd } case linux.S_IFDIR: // Can't open directories with O_CREAT. @@ -847,7 +942,25 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return d.pipe.Open(ctx, mnt, &d.vfsd, opts.Flags, &d.locks) } } - return d.openSpecialFileLocked(ctx, mnt, opts) + + if vfd == nil { + if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil { + return nil, err + } + } + + if trunc { + // If no errors occured so far then update file size in memory. This + // step is required even if !d.cachedMetadataAuthoritative() because + // d.mappings has to be updated. + // d.metadataMu has already been acquired if trunc == true. + d.updateFileSizeLocked(0) + + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + } + return vfd, err } func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { @@ -1013,6 +1126,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } childVFSFD = &fd.vfsfd } + d.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) return childVFSFD, nil } @@ -1064,7 +1178,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return err } } - if err := oldParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + creds := rp.Credentials() + if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } vfsObj := rp.VirtualFilesystem() @@ -1079,12 +1194,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if renamed == nil { return syserror.ENOENT } + if err := oldParent.mayDelete(creds, renamed); err != nil { + return err + } if renamed.isDir() { if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { return syserror.EINVAL } if oldParent != newParent { - if err := renamed.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { + if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { return err } } @@ -1095,7 +1213,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } if oldParent != newParent { - if err := newParent.checkPermissions(rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { + if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } newParent.dirMu.Lock() @@ -1193,10 +1311,12 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.cachedMetadataAuthoritative() { newParent.dirents = nil newParent.touchCMtime() - if renamed.isDir() { + if renamed.isDir() && (replaced == nil || !replaced.isDir()) { + // Increase the link count if we did not replace another directory. newParent.incLinks() } } + vfs.InotifyRename(ctx, &renamed.watches, &oldParent.watches, &newParent.watches, oldName, newName, renamed.isDir()) return nil } @@ -1209,12 +1329,21 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) return err } - return d.setStat(ctx, rp.Credentials(), &opts.Stat, rp.Mount()) + if err := d.setStat(ctx, rp.Credentials(), &opts, rp.Mount()); err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) + return err + } + fs.renameMuRUnlockAndCheckCaching(&ds) + + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + d.InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // StatAt implements vfs.FilesystemImpl.StatAt. @@ -1338,24 +1467,38 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) + return err + } + if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) return err } - return d.setxattr(ctx, rp.Credentials(), &opts) + fs.renameMuRUnlockAndCheckCaching(&ds) + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(&ds) d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) + return err + } + if err := d.removexattr(ctx, rp.Credentials(), name); err != nil { + fs.renameMuRUnlockAndCheckCaching(&ds) return err } - return d.removexattr(ctx, rp.Credentials(), name) + fs.renameMuRUnlockAndCheckCaching(&ds) + + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // PrependPath implements vfs.FilesystemImpl.PrependPath. @@ -1364,3 +1507,7 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +func (fs *filesystem) nextSyntheticIno() inodeNumber { + return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) +} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index d8ae475ed..e20de84b5 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -110,6 +110,26 @@ type filesystem struct { syncMu sync.Mutex syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} + + // syntheticSeq stores a counter to used to generate unique inodeNumber for + // synthetic dentries. + syntheticSeq uint64 +} + +// inodeNumber represents inode number reported in Dirent.Ino. For regular +// dentries, it comes from QID.Path from the 9P server. Synthetic dentries +// have have their inodeNumber generated sequentially, with the MSB reserved to +// prevent conflicts with regular dentries. +type inodeNumber uint64 + +// Reserve MSB for synthetic mounts. +const syntheticInoMask = uint64(1) << 63 + +func inoFromPath(path uint64) inodeNumber { + if path&syntheticInoMask != 0 { + log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) + } + return inodeNumber(path &^ syntheticInoMask) } type filesystemOptions struct { @@ -582,21 +602,27 @@ type dentry struct { // returned by the server. dirents is protected by dirMu. dirents []vfs.Dirent - // Cached metadata; protected by metadataMu and accessed using atomic - // memory operations unless otherwise specified. + // Cached metadata; protected by metadataMu. + // To access: + // - In situations where consistency is not required (like stat), these + // can be accessed using atomic operations only (without locking). + // - Lock metadataMu and can access without atomic operations. + // To mutate: + // - Lock metadataMu and use atomic operations to update because we might + // have atomic readers that don't hold the lock. metadataMu sync.Mutex - ino uint64 // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + ino inodeNumber // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 ctime int64 btime int64 // File size, protected by both metadataMu and dataMu (i.e. both must be - // locked to mutate it). + // locked to mutate it; locking either is sufficient to access it). size uint64 // nlink counts the number of hard links to this dentry. It's updated and @@ -665,6 +691,9 @@ type dentry struct { pipe *pipe.VFSPipe locks vfs.FileLocks + + // Inotify watches for this dentry. + watches vfs.Watches } // dentryAttrMask returns a p9.AttrMask enabling all attributes used by the @@ -701,7 +730,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, file: file, - ino: qid.Path, + ino: inoFromPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -756,8 +785,8 @@ func (d *dentry) cachedMetadataAuthoritative() bool { // updateFromP9Attrs is called to update d's metadata after an update from the // remote filesystem. -func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { - d.metadataMu.Lock() +// Precondition: d.metadataMu must be locked. +func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if mask.Mode { if got, want := uint32(attr.Mode.FileType()), d.fileType(); got != want { d.metadataMu.Unlock() @@ -791,11 +820,8 @@ func (d *dentry) updateFromP9Attrs(mask p9.AttrMask, attr *p9.Attr) { atomic.StoreUint32(&d.nlink, uint32(attr.NLink)) } if mask.Size { - d.dataMu.Lock() - atomic.StoreUint64(&d.size, attr.Size) - d.dataMu.Unlock() + d.updateFileSizeLocked(attr.Size) } - d.metadataMu.Unlock() } // Preconditions: !d.isSynthetic() @@ -807,6 +833,10 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { file p9file handleMuRLocked bool ) + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() d.handleMu.RLock() if !d.handle.file.isNil() { file = d.handle.file @@ -822,7 +852,7 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { if err != nil { return err } - d.updateFromP9Attrs(attrMask, &attr) + d.updateFromP9AttrsLocked(attrMask, &attr) return nil } @@ -845,7 +875,7 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.UID = atomic.LoadUint32(&d.uid) stat.GID = atomic.LoadUint32(&d.gid) stat.Mode = uint16(atomic.LoadUint32(&d.mode)) - stat.Ino = d.ino + stat.Ino = uint64(d.ino) stat.Size = atomic.LoadUint64(&d.size) // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. @@ -858,7 +888,8 @@ func (d *dentry) statTo(stat *linux.Statx) { stat.DevMinor = d.fs.devMinor } -func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mnt *vfs.Mount) error { +func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions, mnt *vfs.Mount) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -866,7 +897,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } if err := mnt.CheckBeginWrite(); err != nil { @@ -883,14 +914,14 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin // Prepare for truncate. if stat.Mask&linux.STATX_SIZE != 0 { - switch d.mode & linux.S_IFMT { - case linux.S_IFREG: + switch mode.FileType() { + case linux.ModeRegular: if !setLocalMtime { // Truncate updates mtime. setLocalMtime = true stat.Mtime.Nsec = linux.UTIME_NOW } - case linux.S_IFDIR: + case linux.ModeDirectory: return syserror.EISDIR default: return syserror.EINVAL @@ -899,8 +930,25 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } d.metadataMu.Lock() defer d.metadataMu.Unlock() + if stat.Mask&linux.STATX_SIZE != 0 { + // The size needs to be changed even when + // !d.cachedMetadataAuthoritative() because d.mappings has to be + // updated. + d.updateFileSizeLocked(stat.Size) + } if !d.isSynthetic() { if stat.Mask != 0 { + if stat.Mask&linux.STATX_SIZE != 0 { + // Check whether to allow a truncate request to be made. + switch d.mode & linux.S_IFMT { + case linux.S_IFREG: + // Allow. + case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } + } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -947,6 +995,8 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } else { atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) } + // Restore mask bits that we cleared earlier. + stat.Mask |= linux.STATX_ATIME } if setLocalMtime { if stat.Mtime.Nsec == linux.UTIME_NOW { @@ -954,48 +1004,56 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, stat *lin } else { atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) } + // Restore mask bits that we cleared earlier. + stat.Mask |= linux.STATX_MTIME } atomic.StoreInt64(&d.ctime, now) - if stat.Mask&linux.STATX_SIZE != 0 { + return nil +} + +// Preconditions: d.metadataMu must be locked. +func (d *dentry) updateFileSizeLocked(newSize uint64) { + d.dataMu.Lock() + oldSize := d.size + atomic.StoreUint64(&d.size, newSize) + // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings + // below. This allows concurrent calls to Read/Translate/etc. These + // functions synchronize with truncation by refusing to use cache + // contents beyond the new d.size. (We are still holding d.metadataMu, + // so we can't race with Write or another truncate.) + d.dataMu.Unlock() + if d.size < oldSize { + oldpgend, _ := usermem.PageRoundUp(oldSize) + newpgend, _ := usermem.PageRoundUp(d.size) + if oldpgend != newpgend { + d.mapsMu.Lock() + d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + d.mapsMu.Unlock() + } + // We are now guaranteed that there are no translations of + // truncated pages, and can remove them from the cache. Since + // truncated pages have been removed from the remote file, they + // should be dropped without being written back. d.dataMu.Lock() - oldSize := d.size - d.size = stat.Size - // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings - // below. This allows concurrent calls to Read/Translate/etc. These - // functions synchronize with truncation by refusing to use cache - // contents beyond the new d.size. (We are still holding d.metadataMu, - // so we can't race with Write or another truncate.) + d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) + d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) d.dataMu.Unlock() - if d.size < oldSize { - oldpgend, _ := usermem.PageRoundUp(oldSize) - newpgend, _ := usermem.PageRoundUp(d.size) - if oldpgend != newpgend { - d.mapsMu.Lock() - d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - d.mapsMu.Unlock() - } - // We are now guaranteed that there are no translations of - // truncated pages, and can remove them from the cache. Since - // truncated pages have been removed from the remote file, they - // should be dropped without being written back. - d.dataMu.Lock() - d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) - d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) - d.dataMu.Unlock() - } } - return nil } func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { + return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid))) +} + func dentryUIDFromP9UID(uid p9.UID) uint32 { if !uid.Ok() { return uint32(auth.OverflowUID) @@ -1051,15 +1109,34 @@ func (d *dentry) decRefLocked() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} +func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { + if d.isDir() { + events |= linux.IN_ISDIR + } + + d.fs.renameMu.RLock() + // The ordering below is important, Linux always notifies the parent first. + if d.parent != nil { + d.parent.watches.Notify(d.name, events, cookie, et, d.isDeleted()) + } + d.watches.Notify("", events, cookie, et, d.isDeleted()) + d.fs.renameMu.RUnlock() +} // Watches implements vfs.DentryImpl.Watches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. func (d *dentry) Watches() *vfs.Watches { - return nil + return &d.watches +} + +// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. +// +// If no watches are left on this dentry and it has no references, cache it. +func (d *dentry) OnZeroWatches() { + if atomic.LoadInt64(&d.refs) == 0 { + d.fs.renameMu.Lock() + d.checkCachingLocked() + d.fs.renameMu.Unlock() + } } // checkCachingLocked should be called after d's reference count becomes 0 or it @@ -1093,6 +1170,9 @@ func (d *dentry) checkCachingLocked() { // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { + if d.isDeleted() { + d.watches.HandleDeletion() + } if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- @@ -1101,6 +1181,14 @@ func (d *dentry) checkCachingLocked() { d.destroyLocked() return } + // If d still has inotify watches and it is not deleted or invalidated, we + // cannot cache it and allow it to be evicted. Otherwise, we will lose its + // watches, even if a new dentry is created for the same file in the future. + // Note that the size of d.watches cannot concurrently transition from zero + // to non-zero, because adding a watch requires holding a reference on d. + if d.watches.Size() > 0 { + return + } // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1206,7 +1294,7 @@ func (d *dentry) setDeleted() { // We only support xattrs prefixed with "user." (see b/148380782). Currently, // there is no need to expose any other xattrs through a gofer. func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { - if d.file.isNil() { + if d.file.isNil() || !d.userXattrSupported() { return nil, nil } xattrMap, err := d.file.listXattr(ctx, size) @@ -1232,6 +1320,9 @@ func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vf if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { return "", syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return "", syserror.ENODATA + } return d.file.getXattr(ctx, opts.Name, opts.Size) } @@ -1245,6 +1336,9 @@ func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vf if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { return syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return syserror.EPERM + } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } @@ -1258,10 +1352,20 @@ func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { return syserror.EOPNOTSUPP } + if !d.userXattrSupported() { + return syserror.EPERM + } return d.file.removeXattr(ctx, name) } -// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDirectory(). +// Extended attributes in the user.* namespace are only supported for regular +// files and directories. +func (d *dentry) userXattrSupported() bool { + filetype := linux.FileMode(atomic.LoadUint32(&d.mode)).FileType() + return filetype == linux.ModeRegular || filetype == linux.ModeDirectory +} + +// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir(). func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { // O_TRUNC unconditionally requires us to obtain a new handle (opened with // O_TRUNC). @@ -1406,7 +1510,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { - return fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, fd.vfsfd.Mount()) + if err := fd.dentry().setStat(ctx, auth.CredentialsFromContext(ctx), &opts, fd.vfsfd.Mount()); err != nil { + return err + } + if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { + fd.dentry().InotifyWithParent(ev, 0, vfs.InodeEvent) + } + return nil } // Listxattr implements vfs.FileDescriptionImpl.Listxattr. @@ -1421,12 +1531,22 @@ func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOption // Setxattr implements vfs.FileDescriptionImpl.Setxattr. func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { - return fd.dentry().setxattr(ctx, auth.CredentialsFromContext(ctx), &opts) + d := fd.dentry() + if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { + return err + } + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // Removexattr implements vfs.FileDescriptionImpl.Removexattr. func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { - return fd.dentry().removexattr(ctx, auth.CredentialsFromContext(ctx), name) + d := fd.dentry() + if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { + return err + } + d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) + return nil } // LockBSD implements vfs.FileDescriptionImpl.LockBSD. diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 724a3f1f7..8792ca4f2 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -126,11 +126,16 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o } func (h *handle) sync(ctx context.Context) error { + // Handle most common case first. if h.fd >= 0 { ctx.UninterruptibleSleepStart(false) err := syscall.Fsync(int(h.fd)) ctx.UninterruptibleSleepFinish(false) return err } + if h.file.isNil() { + // File hasn't been touched, there is nothing to sync. + return nil + } return h.file.fsync(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 0d10cf7ac..09f142cfc 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -24,11 +24,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -67,12 +67,46 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { return d.handle.file.flush(ctx) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + + d := fd.dentry() + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + size := offset + length + + // Allocating a smaller size is a noop. + if size <= d.size { + return nil + } + + d.handleMu.Lock() + defer d.handleMu.Unlock() + + err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + if err != nil { + return err + } + d.dataMu.Lock() + atomic.StoreUint64(&d.size, size) + d.dataMu.Unlock() + if !d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + return nil +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } - if opts.Flags != 0 { + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { return 0, syserror.EOPNOTSUPP } @@ -120,21 +154,53 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + d := fd.dentry() + // If the fd was opened with O_APPEND, make sure the file size is updated. + // There is a possible race here if size is modified externally after + // metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } + } + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + // Set offset to file size if the fd was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) + n, err := fd.pwriteLocked(ctx, src, offset, opts) + return n, offset + n, err +} +// Preconditions: fd.dentry().metatdataMu must be locked. +func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() if d.fs.opts.interop != InteropModeShared { // Compare Linux's mm/filemap.c:__generic_file_write_iter() => // file_update_time(). This is d.touchCMtime(), but without locking @@ -154,12 +220,12 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off return 0, syserror.EINVAL } mr := memmap.MappableRange{pgstart, pgend} - var freed []platform.FileRange + var freed []memmap.FileRange d.dataMu.Lock() cseg := d.cache.LowerBoundSegment(mr.Start) for cseg.Ok() && cseg.Start() < mr.End { cseg = d.cache.Isolate(cseg, mr) - freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) + freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()}) cseg = d.cache.Remove(cseg).NextSegment() } d.dataMu.Unlock() @@ -197,8 +263,8 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -489,15 +555,24 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() + newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence) + if err != nil { + return 0, err + } + fd.off = newOffset + return newOffset, nil +} + +// Calculate the new offset for a seek operation on a regular file. +func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int64, whence int32) (int64, error) { switch whence { case linux.SEEK_SET: // Use offset as specified. case linux.SEEK_CUR: - offset += fd.off + offset += fdOffset case linux.SEEK_END, linux.SEEK_DATA, linux.SEEK_HOLE: // Ensure file size is up to date. - d := fd.dentry() - if fd.filesystem().opts.interop == InteropModeShared { + if !d.cachedMetadataAuthoritative() { if err := d.updateFromGetattr(ctx); err != nil { return 0, err } @@ -525,7 +600,6 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( if offset < 0 { return 0, syserror.EINVAL } - fd.off = offset return offset, nil } @@ -536,20 +610,19 @@ func (fd *regularFileFD) Sync(ctx context.Context) error { func (d *dentry) syncSharedHandle(ctx context.Context) error { d.handleMu.RLock() - if !d.handleWritable { - d.handleMu.RUnlock() - return nil - } - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) - d.dataMu.Unlock() - if err == nil { - // Sync the remote file. - err = d.handle.sync(ctx) + defer d.handleMu.RUnlock() + + if d.handleWritable { + d.dataMu.Lock() + // Write dirty cached data to the remote file. + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } } - d.handleMu.RUnlock() - return err + // Sync the remote file. + return d.handle.sync(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -747,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. func (d *dentry) InvalidateUnsavable(ctx context.Context) error { - // Whether we have a host fd (and consequently what platform.File is + // Whether we have a host fd (and consequently what memmap.File is // mapped) can change across save/restore, so invalidate all translations // unconditionally. d.mapsMu.Lock() @@ -795,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { } } -// dentryPlatformFile implements platform.File. It exists solely because dentry -// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef. +// dentryPlatformFile implements memmap.File. It exists solely because dentry +// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file // is available (i.e. dentry.handle.fd >= 0), and that FD is used for @@ -804,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { type dentryPlatformFile struct { *dentry - // fdRefs counts references on platform.File offsets. fdRefs is protected + // fdRefs counts references on memmap.File offsets. fdRefs is protected // by dentry.dataMu. fdRefs fsutil.FrameRefSet @@ -816,29 +889,29 @@ type dentryPlatformFile struct { hostFileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. -func (d *dentryPlatformFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.IncRefAndAccount(fr) d.dataMu.Unlock() } -// DecRef implements platform.File.DecRef. -func (d *dentryPlatformFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { d.dataMu.Lock() d.fdRefs.DecRefAndAccount(fr) d.dataMu.Unlock() } -// MapInternal implements platform.File.MapInternal. -func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) d.handleMu.RUnlock() return bs, err } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() fd := d.handle.fd diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index e6e29b329..811528982 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -16,6 +16,7 @@ package gofer import ( "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -28,9 +29,9 @@ import ( ) // specialFileFD implements vfs.FileDescriptionImpl for pipes, sockets, device -// special files, and (when filesystemOptions.specialRegularFiles is in effect) -// regular files. specialFileFD differs from regularFileFD by using per-FD -// handles instead of shared per-dentry handles, and never buffering I/O. +// special files, and (when filesystemOptions.regularFilesUseSpecialFileFD is +// in effect) regular files. specialFileFD differs from regularFileFD by using +// per-FD handles instead of shared per-dentry handles, and never buffering I/O. type specialFileFD struct { fileDescription @@ -41,10 +42,10 @@ type specialFileFD struct { // file offset is significant, i.e. a regular file. seekable is immutable. seekable bool - // mayBlock is true if this file description represents a file for which - // queue may send I/O readiness events. mayBlock is immutable. - mayBlock bool - queue waiter.Queue + // haveQueue is true if this file description represents a file for which + // queue may send I/O readiness events. haveQueue is immutable. + haveQueue bool + queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex @@ -54,14 +55,14 @@ type specialFileFD struct { func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG - mayBlock := ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK + haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 fd := &specialFileFD{ - handle: h, - seekable: seekable, - mayBlock: mayBlock, + handle: h, + seekable: seekable, + haveQueue: haveQueue, } fd.LockFD.Init(locks) - if mayBlock && h.fd >= 0 { + if haveQueue { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err } @@ -70,7 +71,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, DenyPRead: !seekable, DenyPWrite: !seekable, }); err != nil { - if mayBlock && h.fd >= 0 { + if haveQueue { fdnotifier.RemoveFD(h.fd) } return nil, err @@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, // Release implements vfs.FileDescriptionImpl.Release. func (fd *specialFileFD) Release() { - if fd.mayBlock && fd.handle.fd >= 0 { + if fd.haveQueue { fdnotifier.RemoveFD(fd.handle.fd) } fd.handle.close(context.Background()) @@ -100,7 +101,7 @@ func (fd *specialFileFD) OnClose(ctx context.Context) error { // Readiness implements waiter.Waitable.Readiness. func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { - if fd.mayBlock { + if fd.haveQueue { return fdnotifier.NonBlockingPoll(fd.handle.fd, mask) } return fd.fileDescription.Readiness(mask) @@ -108,8 +109,9 @@ func (fd *specialFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { // EventRegister implements waiter.Waitable.EventRegister. func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventRegister(e, mask) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventRegister(e, mask) @@ -117,8 +119,9 @@ func (fd *specialFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { // EventUnregister implements waiter.Waitable.EventUnregister. func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { - if fd.mayBlock { + if fd.haveQueue { fd.queue.EventUnregister(e) + fdnotifier.UpdateFD(fd.handle.fd) return } fd.fileDescription.EventUnregister(e) @@ -129,7 +132,11 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs if fd.seekable && offset < 0 { return 0, syserror.EINVAL } - if opts.Flags != 0 { + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { return 0, syserror.EOPNOTSUPP } @@ -138,7 +145,7 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't // hold here since specialFileFD doesn't client-cache data. Just buffer the // read instead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } buf := make([]byte, dst.NumBytes()) @@ -170,35 +177,76 @@ func (fd *specialFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *specialFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset, error. The final +// offset should be ignored by PWrite. +func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if fd.seekable && offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - if opts.Flags != 0 { - return 0, syserror.EOPNOTSUPP + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + d := fd.dentry() + // If the regular file fd was opened with O_APPEND, make sure the file size + // is updated. There is a possible race here if size is modified externally + // after metadata cache is updated. + if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } if fd.seekable { + // We need to hold the metadataMu *while* writing to a regular file. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Set offset to file size if the regular file was opened with O_APPEND. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Holding d.metadataMu is sufficient for reading d.size. + offset = int64(d.size) + } limit, err := vfs.CheckLimit(ctx, offset, src.NumBytes()) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(limit) } // Do a buffered write. See rationale in PRead. - if d := fd.dentry(); d.fs.opts.interop != InteropModeShared { + if d.cachedMetadataAuthoritative() { d.touchCMtime() } buf := make([]byte, src.NumBytes()) // Don't do partial writes if we get a partial read from src. if _, err := src.CopyIn(ctx, buf); err != nil { - return 0, err + return 0, offset, err } n, err := fd.handle.writeFromBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } - return int64(n), err + finalOff = offset + // Update file size for regular files. + if fd.seekable { + finalOff += int64(n) + // d.metadataMu is already locked at this point. + if uint64(finalOff) > d.size { + d.dataMu.Lock() + defer d.dataMu.Unlock() + atomic.StoreUint64(&d.size, uint64(finalOff)) + } + } + return int64(n), finalOff, err } // Write implements vfs.FileDescriptionImpl.Write. @@ -208,8 +256,8 @@ func (fd *specialFileFD) Write(ctx context.Context, src usermem.IOSequence, opts } fd.mu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.mu.Unlock() return n, err } @@ -221,27 +269,15 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( } fd.mu.Lock() defer fd.mu.Unlock() - switch whence { - case linux.SEEK_SET: - // Use offset as given. - case linux.SEEK_CUR: - offset += fd.off - default: - // SEEK_END, SEEK_DATA, and SEEK_HOLE aren't supported since it's not - // clear that file size is even meaningful for these files. - return 0, syserror.EINVAL - } - if offset < 0 { - return 0, syserror.EINVAL + newOffset, err := regularFileSeekLocked(ctx, fd.dentry(), fd.off, offset, whence) + if err != nil { + return 0, err } - fd.off = offset - return offset, nil + fd.off = newOffset + return newOffset, nil } // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - if !fd.vfsfd.IsWritable() { - return nil - } - return fd.handle.sync(ctx) + return fd.dentry().syncSharedHandle(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 1d5aa82dc..0eef4e16e 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -36,7 +36,7 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { } } -// Preconditions: fs.interop != InteropModeShared. +// Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime { return @@ -51,8 +51,8 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) { mnt.EndWrite() } -// Preconditions: fs.interop != InteropModeShared. The caller has successfully -// called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -60,8 +60,8 @@ func (d *dentry) touchCtime() { d.metadataMu.Unlock() } -// Preconditions: fs.interop != InteropModeShared. The caller has successfully -// called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCMtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -70,6 +70,8 @@ func (d *dentry) touchCMtime() { d.metadataMu.Unlock() } +// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has +// locked d.metadataMu. func (d *dentry) touchCMtimeLocked() { now := d.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&d.mtime, now) diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 44a09d87a..bd701bbc7 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -22,6 +22,7 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/fspath", + "//pkg/iovec", "//pkg/log", "//pkg/refs", "//pkg/safemem", @@ -33,7 +34,6 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/platform", "//pkg/sentry/socket/control", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 43a173bc9..c894f2ca0 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -259,7 +259,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL } @@ -373,7 +373,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { // SetStat implements kernfs.Inode. func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - s := opts.Stat + s := &opts.Stat m := s.Mask if m == 0 { @@ -386,7 +386,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre if err := syscall.Fstat(i.hostFD, &hostStat); err != nil { return err } - if err := vfs.CheckSetStat(ctx, creds, &s, linux.FileMode(hostStat.Mode&linux.PermissionsMask), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, linux.FileMode(hostStat.Mode), auth.KUID(hostStat.Uid), auth.KGID(hostStat.Gid)); err != nil { return err } @@ -396,6 +396,9 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } } if m&linux.STATX_SIZE != 0 { + if hostStat.Mode&linux.S_IFMT != linux.S_IFREG { + return syserror.EINVAL + } if err := syscall.Ftruncate(i.hostFD, int64(s.Size)); err != nil { return err } @@ -459,10 +462,12 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u fileType := s.Mode & linux.FileTypeMask // Constrain flags to a subset we can handle. - // TODO(gvisor.dev/issue/1672): implement behavior corresponding to these allowed flags. - flags &= syscall.O_ACCMODE | syscall.O_DIRECT | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND + // + // TODO(gvisor.dev/issue/2601): Support O_NONBLOCK by adding RWF_NOWAIT to pread/pwrite calls. + flags &= syscall.O_ACCMODE | syscall.O_NONBLOCK | syscall.O_DSYNC | syscall.O_SYNC | syscall.O_APPEND - if fileType == syscall.S_IFSOCK { + switch fileType { + case syscall.S_IFSOCK: if i.isTTY { log.Warningf("cannot use host socket fd %d as TTY", i.hostFD) return nil, syserror.ENOTTY @@ -474,31 +479,33 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u } // Currently, we only allow Unix sockets to be imported. return unixsocket.NewFileDescription(ep, ep.Type(), flags, mnt, d, &i.locks) - } - // TODO(gvisor.dev/issue/1672): Allow only specific file types here, so - // that we don't allow importing arbitrary file types without proper - // support. - if i.isTTY { - fd := &TTYFileDescription{ - fileDescription: fileDescription{inode: i}, - termios: linux.DefaultSlaveTermios, + case syscall.S_IFREG, syscall.S_IFIFO, syscall.S_IFCHR: + if i.isTTY { + fd := &TTYFileDescription{ + fileDescription: fileDescription{inode: i}, + termios: linux.DefaultSlaveTermios, + } + fd.LockFD.Init(&i.locks) + vfsfd := &fd.vfsfd + if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return vfsfd, nil } + + fd := &fileDescription{inode: i} fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } return vfsfd, nil - } - fd := &fileDescription{inode: i} - fd.LockFD.Init(&i.locks) - vfsfd := &fd.vfsfd - if err := vfsfd.Init(fd, flags, mnt, d, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err + default: + log.Warningf("cannot import host fd %d with file type %o", i.hostFD, fileType) + return nil, syserror.EPERM } - return vfsfd, nil } // fileDescription is embedded by host fd implementations of FileDescriptionImpl. @@ -530,8 +537,8 @@ func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // Stat implements vfs.FileDescriptionImpl. -func (f *fileDescription) Stat(_ context.Context, opts vfs.StatOptions) (linux.Statx, error) { - return f.inode.Stat(f.vfsfd.Mount().Filesystem(), opts) +func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } // Release implements vfs.FileDescriptionImpl. @@ -539,6 +546,16 @@ func (f *fileDescription) Release() { // noop } +// Allocate implements vfs.FileDescriptionImpl. +func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { + if !f.inode.seekable { + return syserror.ESPIPE + } + + // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds. + return syserror.EOPNOTSUPP +} + // PRead implements FileDescriptionImpl. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode @@ -565,7 +582,7 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts } return n, err } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. + f.offsetMu.Lock() n, err := readFromHostFD(ctx, i.hostFD, dst, f.offset, opts.Flags) f.offset += n @@ -574,8 +591,10 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts } func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // TODO(gvisor.dev/issue/1672): Support select preadv2 flags. - if flags != 0 { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if flags&^linux.RWF_HIPRI != 0 { return 0, syserror.EOPNOTSUPP } reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) @@ -586,41 +605,58 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off // PWrite implements FileDescriptionImpl. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { - i := f.inode - if !i.seekable { + if !f.inode.seekable { return 0, syserror.ESPIPE } - return writeToHostFD(ctx, i.hostFD, src, offset, opts.Flags) + return f.writeToHostFD(ctx, src, offset, opts.Flags) } // Write implements FileDescriptionImpl. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode if !i.seekable { - n, err := writeToHostFD(ctx, i.hostFD, src, -1, opts.Flags) + n, err := f.writeToHostFD(ctx, src, -1, opts.Flags) if isBlockError(err) { err = syserror.ErrWouldBlock } return n, err } - // TODO(gvisor.dev/issue/1672): Cache pages, when forced to do so. - // TODO(gvisor.dev/issue/1672): Write to end of file and update offset if O_APPEND is set on this file. + f.offsetMu.Lock() - n, err := writeToHostFD(ctx, i.hostFD, src, f.offset, opts.Flags) + // NOTE(gvisor.dev/issue/2983): O_APPEND may cause memory corruption if + // another process modifies the host file between retrieving the file size + // and writing to the host fd. This is an unavoidable race condition because + // we cannot enforce synchronization on the host. + if f.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + var s syscall.Stat_t + if err := syscall.Fstat(i.hostFD, &s); err != nil { + f.offsetMu.Unlock() + return 0, err + } + f.offset = s.Size + } + n, err := f.writeToHostFD(ctx, src, f.offset, opts.Flags) f.offset += n f.offsetMu.Unlock() return n, err } -func writeToHostFD(ctx context.Context, hostFD int, src usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // TODO(gvisor.dev/issue/1672): Support select pwritev2 flags. +func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSequence, offset int64, flags uint32) (int64, error) { + hostFD := f.inode.hostFD + // TODO(gvisor.dev/issue/2601): Support select pwritev2 flags. if flags != 0 { return 0, syserror.EOPNOTSUPP } writer := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := src.CopyInTo(ctx, writer) hostfd.PutReadWriterAt(writer) + // NOTE(gvisor.dev/issue/2979): We always sync everything, even for O_DSYNC. + if n > 0 && f.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { + if syncErr := unix.Fsync(hostFD); syncErr != nil { + return int64(n), syncErr + } + } return int64(n), err } @@ -691,8 +727,7 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i // Sync implements FileDescriptionImpl. func (f *fileDescription) Sync(context.Context) error { - // TODO(gvisor.dev/issue/1672): Currently we do not support the SyncData - // optimization, so we always sync everything. + // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go index 8545a82f0..65d3af38c 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/host/mmap.go @@ -19,13 +19,12 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) -// inodePlatformFile implements platform.File. It exists solely because inode -// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef. +// inodePlatformFile implements memmap.File. It exists solely because inode +// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // // inodePlatformFile should only be used if inode.canMap is true. type inodePlatformFile struct { @@ -34,7 +33,7 @@ type inodePlatformFile struct { // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex - // fdRefs counts references on platform.File offsets. It is used solely for + // fdRefs counts references on memmap.File offsets. It is used solely for // memory accounting. fdRefs fsutil.FrameRefSet @@ -45,32 +44,32 @@ type inodePlatformFile struct { fileMapperInitOnce sync.Once } -// IncRef implements platform.File.IncRef. +// IncRef implements memmap.File.IncRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) IncRef(fr platform.FileRange) { +func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) i.fdRefsMu.Unlock() } -// DecRef implements platform.File.DecRef. +// DecRef implements memmap.File.DecRef. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) DecRef(fr platform.FileRange) { +func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) i.fdRefsMu.Unlock() } -// MapInternal implements platform.File.MapInternal. +// MapInternal implements memmap.File.MapInternal. // // Precondition: i.inode.canMap must be true. -func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (i *inodePlatformFile) FD() int { return i.hostFD } diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 38f1fbfba..fd16bd92d 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -47,11 +47,6 @@ func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transpor return ep, nil } -// maxSendBufferSize is the maximum host send buffer size allowed for endpoint. -// -// N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). -const maxSendBufferSize = 8 << 20 - // ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and // transport.Receiver. It is backed by a host fd that was imported at sentry // startup. This fd is shared with a hostfs inode, which retains ownership of @@ -114,10 +109,6 @@ func (c *ConnectedEndpoint) init() *syserr.Error { if err != nil { return syserr.FromError(err) } - if sndbuf > maxSendBufferSize { - log.Warningf("Socket send buffer too large: %d", sndbuf) - return syserr.ErrInvalidEndpointState - } c.stype = linux.SockType(stype) c.sndbuf = int64(sndbuf) diff --git a/pkg/sentry/fsimpl/host/socket_iovec.go b/pkg/sentry/fsimpl/host/socket_iovec.go index 584c247d2..fc0d5fd38 100644 --- a/pkg/sentry/fsimpl/host/socket_iovec.go +++ b/pkg/sentry/fsimpl/host/socket_iovec.go @@ -17,13 +17,10 @@ package host import ( "syscall" - "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/syserror" ) -// maxIovs is the maximum number of iovecs to pass to the host. -var maxIovs = linux.UIO_MAXIOV - // copyToMulti copies as many bytes from src to dst as possible. func copyToMulti(dst [][]byte, src []byte) { for _, d := range dst { @@ -74,7 +71,7 @@ func buildIovec(bufs [][]byte, maxlen int64, truncate bool) (length int64, iovec } } - if iovsRequired > maxIovs { + if iovsRequired > iovec.MaxIovs { // The kernel will reject our call if we pass this many iovs. // Use a single intermediate buffer instead. b := make([]byte, stopLen) diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 179df6c1e..3835557fe 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -70,6 +70,6 @@ go_test( "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index c1215b70a..c6c4472e7 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -101,12 +101,12 @@ func (fd *DynamicBytesFD) Seek(ctx context.Context, offset int64, whence int32) return fd.DynamicBytesFileDescriptionImpl.Seek(ctx, offset, whence) } -// Read implmenets vfs.FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (fd *DynamicBytesFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { return fd.DynamicBytesFileDescriptionImpl.Read(ctx, dst, opts) } -// PRead implmenets vfs.FileDescriptionImpl.PRead. +// PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DynamicBytesFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return fd.DynamicBytesFileDescriptionImpl.PRead(ctx, dst, offset, opts) } @@ -127,7 +127,7 @@ func (fd *DynamicBytesFD) Release() {} // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *DynamicBytesFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(fs, opts) + return fd.inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index 5f7853a2a..1d37ccb98 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -112,7 +112,7 @@ func (fd *GenericDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence return fd.DirectoryFileDescriptionDefaultImpl.PWrite(ctx, src, offset, opts) } -// Release implements vfs.FileDecriptionImpl.Release. +// Release implements vfs.FileDescriptionImpl.Release. func (fd *GenericDirectoryFD) Release() {} func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { @@ -123,7 +123,7 @@ func (fd *GenericDirectoryFD) inode() Inode { return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode } -// IterDirents implements vfs.FileDecriptionImpl.IterDirents. IterDirents holds +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds // o.mu when calling cb. func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { fd.mu.Lock() @@ -132,7 +132,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent opts := vfs.StatOptions{Mask: linux.STATX_INO} // Handle ".". if fd.off == 0 { - stat, err := fd.inode().Stat(fd.filesystem(), opts) + stat, err := fd.inode().Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -152,7 +152,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent if fd.off == 1 { vfsd := fd.vfsfd.VirtualDentry().Dentry() parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode - stat, err := parentInode.Stat(fd.filesystem(), opts) + stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -176,7 +176,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent childIdx := fd.off - 2 for it := fd.children.nthLocked(childIdx); it != nil; it = it.Next() { inode := it.Dentry.Impl().(*Dentry).inode - stat, err := inode.Stat(fd.filesystem(), opts) + stat, err := inode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err } @@ -198,7 +198,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent return err } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() defer fd.mu.Unlock() @@ -226,7 +226,7 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := fd.filesystem() inode := fd.inode() - return inode.Stat(fs, opts) + return inode.Stat(ctx, fs, opts) } // SetStat implements vfs.FileDescriptionImpl.SetStat. @@ -236,6 +236,11 @@ func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptio return inode.SetStat(ctx, fd.filesystem(), creds, opts) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *GenericDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return fd.DirectoryFileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *GenericDirectoryFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 8939871c1..61a36cff9 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -684,7 +684,7 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } - return inode.Stat(fs.VFSFilesystem(), opts) + return inode.Stat(ctx, fs.VFSFilesystem(), opts) } // StatFSAt implements vfs.FilesystemImpl.StatFSAt. diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 650bd7b88..579e627f0 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -243,7 +243,7 @@ func (a *InodeAttrs) Mode() linux.FileMode { // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. -func (a *InodeAttrs) Stat(*vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { +func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK stat.DevMajor = a.devMajor @@ -267,7 +267,7 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { return syserror.EPERM } - if err := vfs.CheckSetStat(ctx, creds, &opts.Stat, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -293,6 +293,8 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut // inode numbers are immutable after node creation. // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. + // Also, STATX_SIZE will need some special handling, because read-only static + // files should return EIO for truncate operations. return nil } @@ -469,6 +471,8 @@ func (o *OrderedChildren) Unlink(ctx context.Context, name string, child *vfs.De if err := o.checkExistingLocked(name, child); err != nil { return err } + + // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. o.removeLocked(name) return nil } @@ -516,6 +520,8 @@ func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, c if err := o.checkExistingLocked(oldname, child); err != nil { return nil, err } + + // TODO(gvisor.dev/issue/3027): Check sticky bit before removing. replaced := dst.replaceChildLocked(newname, child) return replaced, nil } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index bbee8ccda..46f207664 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -227,16 +227,19 @@ func (d *Dentry) destroy() { // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) {} +// Although Linux technically supports inotify on pseudo filesystems (inotify +// is implemented at the vfs layer), it is not particularly useful. It is left +// unimplemented until someone actually needs it. +func (d *Dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) {} // Watches implements vfs.DentryImpl.Watches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. func (d *Dentry) Watches() *vfs.Watches { return nil } +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +func (d *Dentry) OnZeroWatches() {} + // InsertChild inserts child into the vfs dentry cache with the given name under // this dentry. This does not update the directory inode, so calling this on // its own isn't sufficient to insert a child into a directory. InsertChild @@ -343,7 +346,7 @@ type inodeMetadata interface { // Stat returns the metadata for this inode. This corresponds to // vfs.FilesystemImpl.StatAt. - Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) + Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) // SetStat updates the metadata for this inode. This corresponds to // vfs.FilesystemImpl.SetStatAt. Implementations are responsible for checking @@ -425,10 +428,10 @@ type inodeDynamicLookup interface { // IterDirents is used to iterate over dynamically created entries. It invokes // cb on each entry in the directory represented by the FileDescription. // 'offset' is the offset for the entire IterDirents call, which may include - // results from the caller. 'relOffset' is the offset inside the entries - // returned by this IterDirents invocation. In other words, - // 'offset+relOffset+1' is the value that should be set in vfs.Dirent.NextOff, - // while 'relOffset' is the place where iteration should start from. + // results from the caller (e.g. "." and ".."). 'relOffset' is the offset + // inside the entries returned by this IterDirents invocation. In other words, + // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as + // the return value, while 'relOffset' is the place to start iteration. IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index ff82e1f20..6b705e955 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1104,7 +1104,7 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, rp.Credentials(), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := rp.Mount() diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index a3c1f7a8d..c0749e711 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -151,7 +151,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) - if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts.Stat, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { return err } mnt := fd.vfsfd.Mount() @@ -176,7 +176,7 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// StatFS implements vfs.FileDesciptionImpl.StatFS. +// StatFS implements vfs.FileDescriptionImpl.StatFS. func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e11a3ff19..e720d4825 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -528,6 +528,11 @@ func (d *dentry) Watches() *vfs.Watches { return nil } +// OnZeroWatches implements vfs.DentryImpl.OnZeroWatches. +// +// TODO(gvisor.dev/issue/1479): Implement inotify. +func (d *dentry) OnZeroWatches() {} + // iterLayers invokes yield on each layer comprising d, from top to bottom. If // any call to yield returns false, iterLayer stops iteration. func (d *dentry) iterLayers(yield func(vd vfs.VirtualDentry, isUpper bool) bool) { diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index dd7eaf4a8..811f80a5f 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -115,7 +115,7 @@ func (i *inode) Mode() linux.FileMode { } // Stat implements kernfs.Inode.Stat. -func (i *inode) Stat(vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (i *inode) Stat(_ context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { ts := linux.NsecToStatxTimestamp(i.ctime.Nanoseconds()) return linux.Statx{ Mask: linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_NLINK | linux.STATX_UID | linux.STATX_GID | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME | linux.STATX_INO | linux.STATX_SIZE | linux.STATX_BLOCKS, diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 36a89540c..79c2725f3 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -128,7 +128,7 @@ func (fd *subtasksFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallbac return fd.GenericDirectoryFD.IterDirents(ctx, cb) } -// Seek implements vfs.FileDecriptionImpl.Seek. +// Seek implements vfs.FileDescriptionImpl.Seek. func (fd *subtasksFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { if fd.task.ExitState() >= kernel.TaskExitZombie { return 0, syserror.ENOENT @@ -165,8 +165,8 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v } // Stat implements kernfs.Inode. -func (i *subtasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 8bb2b0ce1..a5c7aa470 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -156,8 +156,8 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. } // Stat implements kernfs.Inode. -func (i *taskOwnedInode) Stat(fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.Inode.Stat(fs, opts) +func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 7debdb07a..fea29e5f0 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDynamicLookup. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, absOffset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -72,7 +72,6 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, abs } }) - offset := absOffset + relOffset typ := uint8(linux.DT_REG) if i.produceSymlink { typ = linux.DT_LNK diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index ba4405026..859b7d727 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -35,6 +35,10 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// "There is an (arbitrary) limit on the number of lines in the file. As at +// Linux 3.18, the limit is five lines." - user_namespaces(7) +const maxIDMapLines = 5 + // mm gets the kernel task's MemoryManager. No additional reference is taken on // mm here. This is safe because MemoryManager.destroy is required to leave the // MemoryManager in a state where it's still usable as a DynamicBytesSource. @@ -227,8 +231,9 @@ func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Linux will return envp up to and including the first NULL character, // so find it. - if end := bytes.IndexByte(buf.Bytes()[ar.Length():], 0); end != -1 { - buf.Truncate(end) + envStart := int(ar.Length()) + if nullIdx := bytes.IndexByte(buf.Bytes()[envStart:], 0); nullIdx != -1 { + buf.Truncate(envStart + nullIdx) } } @@ -283,7 +288,8 @@ func (d *commData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -// idMapData implements vfs.DynamicBytesSource for /proc/[pid]/{gid_map|uid_map}. +// idMapData implements vfs.WritableDynamicBytesSource for +// /proc/[pid]/{gid_map|uid_map}. // // +stateify savable type idMapData struct { @@ -295,7 +301,7 @@ type idMapData struct { var _ dynamicInode = (*idMapData)(nil) -// Generate implements vfs.DynamicBytesSource.Generate. +// Generate implements vfs.WritableDynamicBytesSource.Generate. func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error { var entries []auth.IDMapEntry if d.gids { @@ -309,6 +315,60 @@ func (d *idMapData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + // "In addition, the number of bytes written to the file must be less than + // the system page size, and the write must be performed at the start of + // the file ..." - user_namespaces(7) + srclen := src.NumBytes() + if srclen >= usermem.PageSize || offset != 0 { + return 0, syserror.EINVAL + } + b := make([]byte, srclen) + if _, err := src.CopyIn(ctx, b); err != nil { + return 0, err + } + + // Truncate from the first NULL byte. + var nul int64 + nul = int64(bytes.IndexByte(b, 0)) + if nul == -1 { + nul = srclen + } + b = b[:nul] + // Remove the last \n. + if nul >= 1 && b[nul-1] == '\n' { + b = b[:nul-1] + } + lines := bytes.SplitN(b, []byte("\n"), maxIDMapLines+1) + if len(lines) > maxIDMapLines { + return 0, syserror.EINVAL + } + + entries := make([]auth.IDMapEntry, len(lines)) + for i, l := range lines { + var e auth.IDMapEntry + _, err := fmt.Sscan(string(l), &e.FirstID, &e.FirstParentID, &e.Length) + if err != nil { + return 0, syserror.EINVAL + } + entries[i] = e + } + var err error + if d.gids { + err = d.task.UserNamespace().SetGIDMap(ctx, entries) + } else { + err = d.task.UserNamespace().SetUIDMap(ctx, entries) + } + if err != nil { + return 0, err + } + + // On success, Linux's kernel/user_namespace.c:map_write() always returns + // count, even if fewer bytes were used. + return int64(srclen), nil +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -816,7 +876,7 @@ var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) // Stat implements FileDescriptionImpl. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() - return fd.inode.Stat(vfs, opts) + return fd.inode.Stat(ctx, vfs, opts) } // SetStat implements FileDescriptionImpl. diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 2f214d0c2..6d2b90a8b 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -206,8 +206,8 @@ func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs. return fd.VFSFileDescription(), nil } -func (i *tasksInode) Stat(vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - stat, err := i.InodeAttrs.Stat(vsfs, opts) +func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index a741e2bb6..1b548ccd4 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -29,6 +29,6 @@ go_test( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index fe02f7ee9..01ce30a4d 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -138,7 +138,7 @@ type cpuFile struct { // Generate implements vfs.DynamicBytesSource.Generate. func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { - fmt.Fprintf(buf, "0-%d", c.maxCores-1) + fmt.Fprintf(buf, "0-%d\n", c.maxCores-1) return nil } diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 4b3602d47..242d5fd12 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -51,7 +51,7 @@ func TestReadCPUFile(t *testing.T) { k := kernel.KernelFromContext(s.Ctx) maxCPUCores := k.ApplicationCores() - expected := fmt.Sprintf("0-%d", maxCPUCores-1) + expected := fmt.Sprintf("0-%d\n", maxCPUCores-1) for _, fname := range []string{"online", "possible", "present"} { pop := s.PathOpAtRoot(fmt.Sprintf("devices/system/cpu/%s", fname)) diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index 0e4053a46..400a97996 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -32,6 +32,6 @@ go_library( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/usermem", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index c16a36cdb..e743e8114 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -62,6 +62,7 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, } @@ -73,7 +74,7 @@ func Boot() (*kernel.Kernel, error) { k.SetMemoryFile(mf) // Pass k as the platform since it is savable, unlike the actual platform. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -103,11 +104,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } - kernel.VFS2Enabled = true - - if err := k.VFS().Init(); err != nil { - return nil, fmt.Errorf("VFS init: %v", err) - } k.VFS().MustRegisterFilesystemType(tmpfs.Name, &tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, AllowUserList: true, diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 913b8a6c5..0a1ad4765 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -79,7 +79,10 @@ func (dir *directory) removeChildLocked(child *dentry) { dir.iterMu.Lock() dir.childList.Remove(child) dir.iterMu.Unlock() - child.unlinked = true +} + +func (dir *directory) mayDelete(creds *auth.Credentials, child *dentry) error { + return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&dir.inode.mode)), auth.KUID(atomic.LoadUint32(&child.inode.uid))) } type directoryFD struct { @@ -107,13 +110,14 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba fs := fd.filesystem() dir := fd.inode().impl.(*directory) + defer fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + // fs.mu is required to read d.parent and dentry.name. fs.mu.RLock() defer fs.mu.RUnlock() dir.iterMu.Lock() defer dir.iterMu.Unlock() - fd.dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) fd.inode().touchAtime(fd.vfsfd.Mount()) if fd.off == 0 { diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 72399b321..ef210a69b 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -79,7 +79,7 @@ afterSymlink: } if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // Symlink traversal updates access time. - atomic.StoreInt64(&d.inode.atime, d.inode.fs.clock.Now().Nanoseconds()) + child.inode.touchAtime(rp.Mount()) if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err } @@ -182,7 +182,7 @@ func (fs *filesystem) doCreateAt(rp *vfs.ResolvingPath, dir bool, create func(pa if dir { ev |= linux.IN_ISDIR } - parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent) + parentDir.inode.watches.Notify(name, uint32(ev), 0, vfs.InodeEvent, false /* unlinked */) parentDir.inode.touchCMtime() return nil } @@ -237,18 +237,22 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. return syserror.EXDEV } d := vd.Dentry().Impl().(*dentry) - if d.inode.isDir() { + i := d.inode + if i.isDir() { return syserror.EPERM } - if d.inode.nlink == 0 { + if err := vfs.MayLink(auth.CredentialsFromContext(ctx), linux.FileMode(atomic.LoadUint32(&i.mode)), auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + return err + } + if i.nlink == 0 { return syserror.ENOENT } - if d.inode.nlink == maxLinks { + if i.nlink == maxLinks { return syserror.EMLINK } - d.inode.incLinksLocked() - d.inode.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent) - parentDir.insertChildLocked(fs.newDentry(d.inode), name) + i.incLinksLocked() + i.watches.Notify("", linux.IN_ATTRIB, 0, vfs.InodeEvent, false /* unlinked */) + parentDir.insertChildLocked(fs.newDentry(i), name) return nil }) } @@ -273,7 +277,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v creds := rp.Credentials() var childInode *inode switch opts.Mode.FileType() { - case 0, linux.S_IFREG: + case linux.S_IFREG: childInode = fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) case linux.S_IFIFO: childInode = fs.newNamedPipe(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode) @@ -364,10 +368,13 @@ afterTrailingSymlink: if err != nil { return nil, err } - parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent) + parentDir.inode.watches.Notify(name, linux.IN_CREATE, 0, vfs.PathEvent, false /* unlinked */) parentDir.inode.touchCMtime() return fd, nil } + if mustCreate { + return nil, syserror.EEXIST + } // Is the file mounted over? if err := rp.CheckMount(&child.vfsd); err != nil { return nil, err @@ -375,7 +382,7 @@ afterTrailingSymlink: // Do we need to resolve a trailing symlink? if symlink, ok := child.inode.impl.(*symlink); ok && rp.ShouldFollowSymlink() { // Symlink traversal updates access time. - atomic.StoreInt64(&child.inode.atime, child.inode.fs.clock.Now().Nanoseconds()) + child.inode.touchAtime(rp.Mount()) if err := rp.HandleSymlink(symlink.target); err != nil { return nil, err } @@ -400,10 +407,10 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open case *regularFile: var fd regularFileFD fd.LockFD.Init(&d.inode.locks) - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil { return nil, err } - if opts.Flags&linux.O_TRUNC != 0 { + if !afterCreate && opts.Flags&linux.O_TRUNC != 0 { if _, err := impl.truncate(0); err != nil { return nil, err } @@ -416,7 +423,7 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open } var fd directoryFD fd.LockFD.Init(&d.inode.locks) - if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + if err := fd.vfsfd.Init(&fd, opts.Flags, rp.Mount(), &d.vfsd, &vfs.FileDescriptionOptions{AllowDirectIO: true}); err != nil { return nil, err } return &fd.vfsfd, nil @@ -485,6 +492,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if !ok { return syserror.ENOENT } + if err := oldParentDir.mayDelete(rp.Credentials(), renamed); err != nil { + return err + } // Note that we don't need to call rp.CheckMount(), since if renamed is a // mount point then we want to rename the mount point, not anything in the // mounted filesystem. @@ -599,6 +609,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if !ok { return syserror.ENOENT } + if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { + return err + } childDir, ok := child.inode.impl.(*directory) if !ok { return syserror.ENOTDIR @@ -618,7 +631,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } parentDir.removeChildLocked(child) - parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent) + parentDir.inode.watches.Notify(name, linux.IN_DELETE|linux.IN_ISDIR, 0, vfs.InodeEvent, true /* unlinked */) // Remove links for child, child/., and child/.. child.inode.decLinksLocked() child.inode.decLinksLocked() @@ -631,14 +644,16 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // SetStatAt implements vfs.FilesystemImpl.SetStatAt. func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetStatOptions) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } - if err := d.inode.setStat(ctx, rp.Credentials(), &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, rp.Credentials(), &opts); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() if ev := vfs.InotifyEventFromStatMask(opts.Stat.Mask); ev != 0 { d.InotifyWithParent(ev, 0, vfs.InodeEvent) @@ -707,6 +722,9 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if !ok { return syserror.ENOENT } + if err := parentDir.mayDelete(rp.Credentials(), child); err != nil { + return err + } if child.inode.isDir() { return syserror.EISDIR } @@ -781,14 +799,16 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil @@ -797,14 +817,16 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt // RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(rp) if err != nil { + fs.mu.RUnlock() return err } if err := d.inode.removexattr(rp.Credentials(), name); err != nil { + fs.mu.RUnlock() return err } + fs.mu.RUnlock() d.InotifyWithParent(linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index bfd9c5995..abbaa5d60 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -274,11 +274,35 @@ func (fd *regularFileFD) Release() { // noop } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + f := fd.inode().impl.(*regularFile) + + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + oldSize := f.size + size := offset + length + if oldSize >= size { + return nil + } + _, err := f.truncateLocked(size) + return err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } + + // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since + // all state is in-memory. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { + return 0, syserror.EOPNOTSUPP + } + if dst.NumBytes() == 0 { return 0, nil } @@ -301,40 +325,60 @@ func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { if offset < 0 { - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL + } + + // Check that flags are supported. RWF_DSYNC/RWF_SYNC can be ignored since + // all state is in-memory. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^(linux.RWF_HIPRI|linux.RWF_DSYNC|linux.RWF_SYNC) != 0 { + return 0, offset, syserror.EOPNOTSUPP } + srclen := src.NumBytes() if srclen == 0 { - return 0, nil + return 0, offset, nil } f := fd.inode().impl.(*regularFile) + f.inode.mu.Lock() + defer f.inode.mu.Unlock() + // If the file is opened with O_APPEND, update offset to file size. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking f.inode.mu is sufficient for reading f.size. + offset = int64(f.size) + } if end := offset + srclen; end < offset { // Overflow. - return 0, syserror.EINVAL + return 0, offset, syserror.EINVAL } - var err error srclen, err = vfs.CheckLimit(ctx, offset, srclen) if err != nil { - return 0, err + return 0, offset, err } src = src.TakeFirst64(srclen) - f.inode.mu.Lock() rw := getRegularFileReadWriter(f, offset) n, err := src.CopyInTo(ctx, rw) - fd.inode().touchCMtimeLocked() - f.inode.mu.Unlock() + f.inode.touchCMtimeLocked() putRegularFileReadWriter(rw) - return n, err + return n, n + offset, err } // Write implements vfs.FileDescriptionImpl.Write. func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { fd.offMu.Lock() - n, err := fd.PWrite(ctx, src, fd.off, opts) - fd.off += n + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off fd.offMu.Unlock() return n, err } diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index a94333ee0..2545d88e9 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -215,11 +215,6 @@ type dentry struct { // filesystem.mu. name string - // unlinked indicates whether this dentry has been unlinked from its parent. - // It is only set to true on an unlink operation, and never set from true to - // false. unlinked is protected by filesystem.mu. - unlinked bool - // dentryEntry (ugh) links dentries into their parent directory.childList. dentryEntry @@ -259,18 +254,22 @@ func (d *dentry) DecRef() { } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. -func (d *dentry) InotifyWithParent(events uint32, cookie uint32, et vfs.EventType) { +func (d *dentry) InotifyWithParent(events, cookie uint32, et vfs.EventType) { if d.inode.isDir() { events |= linux.IN_ISDIR } + // tmpfs never calls VFS.InvalidateDentry(), so d.vfsd.IsDead() indicates + // that d was deleted. + deleted := d.vfsd.IsDead() + + d.inode.fs.mu.RLock() // The ordering below is important, Linux always notifies the parent first. if d.parent != nil { - // Note that d.parent or d.name may be stale if there is a concurrent - // rename operation. Inotify does not provide consistency guarantees. - d.parent.inode.watches.NotifyWithExclusions(d.name, events, cookie, et, d.unlinked) + d.parent.inode.watches.Notify(d.name, events, cookie, et, deleted) } - d.inode.watches.Notify("", events, cookie, et) + d.inode.watches.Notify("", events, cookie, et, deleted) + d.inode.fs.mu.RUnlock() } // Watches implements vfs.DentryImpl.Watches. @@ -278,6 +277,9 @@ func (d *dentry) Watches() *vfs.Watches { return &d.inode.watches } +// OnZeroWatches implements vfs.Dentry.OnZeroWatches. +func (d *dentry) OnZeroWatches() {} + // inode represents a filesystem object. type inode struct { // fs is the owning filesystem. fs is immutable. @@ -336,7 +338,6 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.ctime = now i.mtime = now // i.nlink initialized by caller - i.watches = vfs.Watches{} i.impl = impl } @@ -451,7 +452,8 @@ func (i *inode) statTo(stat *linux.Statx) { } } -func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx) error { +func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error { + stat := &opts.Stat if stat.Mask == 0 { return nil } @@ -459,7 +461,7 @@ func (i *inode) setStat(ctx context.Context, creds *auth.Credentials, stat *linu return syserror.EPERM } mode := linux.FileMode(atomic.LoadUint32(&i.mode)) - if err := vfs.CheckSetStat(ctx, creds, stat, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { + if err := vfs.CheckSetStat(ctx, creds, opts, mode, auth.KUID(atomic.LoadUint32(&i.uid)), auth.KGID(atomic.LoadUint32(&i.gid))); err != nil { return err } i.mu.Lock() @@ -694,7 +696,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) d := fd.dentry() - if err := d.inode.setStat(ctx, creds, &opts.Stat); err != nil { + if err := d.inode.setStat(ctx, creds, &opts); err != nil { return err } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 1510a7c26..f6886a758 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -132,6 +132,7 @@ go_library( "task_stop.go", "task_syscall.go", "task_usermem.go", + "task_work.go", "thread_group.go", "threads.go", "timekeeper.go", @@ -200,6 +201,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/state", "//pkg/state/statefile", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 3d78cd48f..4c0f1e41f 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -107,7 +107,7 @@ type EventPoll struct { // different lock to avoid circular lock acquisition order involving // the wait queue mutexes and mu. The full order is mu, observed file // wait queue mutex, then listsMu; this allows listsMu to be acquired - // when readyCallback is called. + // when (*pollEntry).Callback is called. // // An entry is always in one of the following lists: // readyList -- when there's a chance that it's ready to have @@ -116,7 +116,7 @@ type EventPoll struct { // readEvents() functions always call the entry's file // Readiness() function to confirm it's ready. // waitingList -- when there's no chance that the entry is ready, - // so it's waiting for the readyCallback to be called + // so it's waiting for the (*pollEntry).Callback to be called // on it before it gets moved to the readyList. // disabledList -- when the entry is disabled. This happens when // a one-shot entry gets delivered via readEvents(). @@ -269,21 +269,19 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { return ret } -// readyCallback is called when one of the files we're polling becomes ready. It -// moves said file to the readyList if it's currently in the waiting list. -type readyCallback struct{} - // Callback implements waiter.EntryCallback.Callback. -func (*readyCallback) Callback(w *waiter.Entry) { - entry := w.Context.(*pollEntry) - e := entry.epoll +// +// Callback is called when one of the files we're polling becomes ready. It +// moves said file to the readyList if it's currently in the waiting list. +func (p *pollEntry) Callback(*waiter.Entry) { + e := p.epoll e.listsMu.Lock() - if entry.curList == &e.waitingList { - e.waitingList.Remove(entry) - e.readyList.PushBack(entry) - entry.curList = &e.readyList + if p.curList == &e.waitingList { + e.waitingList.Remove(p) + e.readyList.PushBack(p) + p.curList = &e.readyList e.listsMu.Unlock() e.Notify(waiter.EventIn) @@ -310,7 +308,7 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { // Check if the file happens to already be in a ready state. ready := f.Readiness(entry.mask) & entry.mask if ready != 0 { - (*readyCallback).Callback(nil, &entry.waiter) + entry.Callback(&entry.waiter) } } @@ -380,10 +378,9 @@ func (e *EventPoll) AddEntry(id FileIdentifier, flags EntryFlags, mask waiter.Ev userData: data, epoll: e, flags: flags, - waiter: waiter.Entry{Callback: &readyCallback{}}, mask: mask, } - entry.waiter.Context = entry + entry.waiter.Callback = entry e.files[id] = entry entry.file = refs.NewWeakRef(id.File, entry) @@ -406,7 +403,7 @@ func (e *EventPoll) UpdateEntry(id FileIdentifier, flags EntryFlags, mask waiter } // Unregister the old mask and remove entry from the list it's in, so - // readyCallback is guaranteed to not be called on this entry anymore. + // (*pollEntry).Callback is guaranteed to not be called on this entry anymore. entry.id.File.EventUnregister(&entry.waiter) // Remove entry from whatever list it's in. This ensure that no other diff --git a/pkg/sentry/kernel/epoll/epoll_state.go b/pkg/sentry/kernel/epoll/epoll_state.go index 8e9f200d0..7c61e0258 100644 --- a/pkg/sentry/kernel/epoll/epoll_state.go +++ b/pkg/sentry/kernel/epoll/epoll_state.go @@ -21,8 +21,7 @@ import ( // afterLoad is invoked by stateify. func (p *pollEntry) afterLoad() { - p.waiter = waiter.Entry{Callback: &readyCallback{}} - p.waiter.Context = p + p.waiter.Callback = p p.file = refs.NewWeakRef(p.id.File, p) p.id.File.EventRegister(&p.waiter, p.mask) } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index b9126e946..2b3955598 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -11,6 +11,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/vfs", "//pkg/sync", "//pkg/waiter", ], diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index d32c3e90a..153d2cd9b 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -20,15 +20,21 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new FileAsync. +// New creates a new fs.FileAsync. func New() fs.FileAsync { return &FileAsync{} } +// NewVFS2 creates a new vfs.FileAsync. +func NewVFS2() vfs.FileAsync { + return &FileAsync{} +} + // FileAsync sends signals when the registered file is ready for IO. // // +stateify savable @@ -170,3 +176,13 @@ func (a *FileAsync) SetOwnerProcessGroup(requester *kernel.Task, recipient *kern a.recipientTG = nil a.recipientPG = recipient } + +// ClearOwner unsets the current signal recipient. +func (a *FileAsync) ClearOwner() { + a.mu.Lock() + defer a.mu.Unlock() + a.requester = nil + a.recipientT = nil + a.recipientTG = nil + a.recipientPG = nil +} diff --git a/pkg/sentry/kernel/futex/futex.go b/pkg/sentry/kernel/futex/futex.go index 732e66da4..bcc1b29a8 100644 --- a/pkg/sentry/kernel/futex/futex.go +++ b/pkg/sentry/kernel/futex/futex.go @@ -717,10 +717,10 @@ func (m *Manager) lockPILocked(w *Waiter, t Target, addr usermem.Addr, tid uint3 } } -// UnlockPI unlock the futex following the Priority-inheritance futex -// rules. The address provided must contain the caller's TID. If there are -// waiters, TID of the next waiter (FIFO) is set to the given address, and the -// waiter woken up. If there are no waiters, 0 is set to the address. +// UnlockPI unlocks the futex following the Priority-inheritance futex rules. +// The address provided must contain the caller's TID. If there are waiters, +// TID of the next waiter (FIFO) is set to the given address, and the waiter +// woken up. If there are no waiters, 0 is set to the address. func (m *Manager) UnlockPI(t Target, addr usermem.Addr, tid uint32, private bool) error { k, err := getKey(t, addr, private) if err != nil { diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 554a42e05..15dae0f5b 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -34,7 +34,6 @@ package kernel import ( "errors" "fmt" - "io" "path/filepath" "sync/atomic" "time" @@ -73,6 +72,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/uniqueid" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -81,6 +81,10 @@ import ( // easy access everywhere. To be removed once VFS2 becomes the default. var VFS2Enabled = false +// FUSEEnabled is set to true when FUSE is enabled. Added as a global for allow +// easy access everywhere. To be removed once FUSE is completed. +var FUSEEnabled = false + // Kernel represents an emulated Linux kernel. It must be initialized by calling // Init() or LoadFrom(). // @@ -417,7 +421,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w io.Writer) error { +func (k *Kernel) SaveTo(w wire.Writer) error { saveStart := time.Now() ctx := k.SupervisorContext() @@ -473,18 +477,18 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { + if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { + stats, err := state.Save(k.SupervisorContext(), w, k) + if err != nil { return err } - log.Infof("Kernel save stats: %s", &stats) + log.Infof("Kernel save stats: %s", stats.String()) log.Infof("Kernel save took [%s].", time.Since(kernelStart)) // Save the memory file's state. @@ -629,7 +633,7 @@ func (ts *TaskSet) unregisterEpollWaiters() { } // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -640,7 +644,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { + if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -655,11 +659,11 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. kernelStart := time.Now() - var stats state.Stats - if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { + stats, err := state.Load(k.SupervisorContext(), r, k) + if err != nil { return err } - log.Infof("Kernel load stats: %s", &stats) + log.Infof("Kernel load stats: %s", stats.String()) log.Infof("Kernel load took [%s].", time.Since(kernelStart)) // rootNetworkNamespace should be populated after loading the state file. @@ -1465,6 +1469,11 @@ func (k *Kernel) NowMonotonic() int64 { return now } +// AfterFunc implements tcpip.Clock.AfterFunc. +func (k *Kernel) AfterFunc(d time.Duration, f func()) tcpip.Timer { + return ktime.TcpipAfterFunc(k.realtimeClock, d, f) +} + // SetMemoryFile sets Kernel.mf. SetMemoryFile must be called before Init or // LoadFrom. func (k *Kernel) SetMemoryFile(mf *pgalloc.MemoryFile) { diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index a4519363f..45d4c5fc1 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -200,6 +200,11 @@ func (fd *VFSPipeFD) Readiness(mask waiter.EventMask) waiter.EventMask { } } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *VFSPipeFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ESPIPE +} + // EventRegister implements waiter.Waitable.EventRegister. func (fd *VFSPipeFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { fd.pipe.EventRegister(e, mask) diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index bfd779837..c211fc8d0 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -20,7 +20,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/pgalloc", - "//pkg/sentry/platform", "//pkg/sentry/usage", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index f66cfcc7f..55b4c2cdb 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -45,7 +45,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -370,7 +369,7 @@ type Shm struct { // fr is the offset into mfp.MemoryFile() that backs this contents of this // segment. Immutable. - fr platform.FileRange + fr memmap.FileRange // mu protects all fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index 4607cde2f..a83ce219c 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -98,6 +98,15 @@ func (s *syslog) Log() []byte { s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, selectMessage()))...) } + if VFS2Enabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up VFS2..."))...) + if FUSEEnabled { + time += rand.Float64() / 2 + s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Setting up FUSE..."))...) + } + } + time += rand.Float64() / 2 s.msg = append(s.msg, []byte(fmt.Sprintf(format, time, "Ready!"))...) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f48247c94..c4db05bd8 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -68,6 +68,21 @@ type Task struct { // runState is exclusive to the task goroutine. runState taskRunState + // taskWorkCount represents the current size of the task work queue. It is + // used to avoid acquiring taskWorkMu when the queue is empty. + // + // Must accessed with atomic memory operations. + taskWorkCount int32 + + // taskWorkMu protects taskWork. + taskWorkMu sync.Mutex `state:"nosave"` + + // taskWork is a queue of work to be executed before resuming user execution. + // It is similar to the task_work mechanism in Linux. + // + // taskWork is exclusive to the task goroutine. + taskWork []TaskWorker + // haveSyscallReturn is true if tc.Arch().Return() represents a value // returned by a syscall (or set by ptrace after a syscall). // @@ -550,6 +565,10 @@ type Task struct { // futexWaiter is exclusive to the task goroutine. futexWaiter *futex.Waiter `state:"nosave"` + // robustList is a pointer to the head of the tasks's robust futex + // list. + robustList usermem.Addr + // startTime is the real time at which the task started. It is set when // a Task is created or invokes execve(2). // diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 9b69f3cbe..7803b98d0 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -207,6 +207,9 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { return flags.CloseOnExec }) + // Handle the robust futex list. + t.exitRobustList() + // NOTE(b/30815691): We currently do not implement privileged // executables (set-user/group-ID bits and file capabilities). This // allows us to unconditionally enable user dumpability on the new mm. diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c4ade6e8e..231ac548a 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -253,6 +253,9 @@ func (*runExitMain) execute(t *Task) taskRunState { } } + // Handle the robust futex list. + t.exitRobustList() + // Deactivate the address space and update max RSS before releasing the // task's MM. t.Deactivate() diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index a53e77c9f..4b535c949 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -15,6 +15,7 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -52,3 +53,127 @@ func (t *Task) LoadUint32(addr usermem.Addr) (uint32, error) { func (t *Task) GetSharedKey(addr usermem.Addr) (futex.Key, error) { return t.MemoryManager().GetSharedFutexKey(t, addr) } + +// GetRobustList sets the robust futex list for the task. +func (t *Task) GetRobustList() usermem.Addr { + t.mu.Lock() + addr := t.robustList + t.mu.Unlock() + return addr +} + +// SetRobustList sets the robust futex list for the task. +func (t *Task) SetRobustList(addr usermem.Addr) { + t.mu.Lock() + t.robustList = addr + t.mu.Unlock() +} + +// exitRobustList walks the robust futex list, marking locks dead and notifying +// wakers. It corresponds to Linux's exit_robust_list(). Following Linux, +// errors are silently ignored. +func (t *Task) exitRobustList() { + t.mu.Lock() + addr := t.robustList + t.robustList = 0 + t.mu.Unlock() + + if addr == 0 { + return + } + + var rl linux.RobustListHead + if _, err := rl.CopyIn(t, usermem.Addr(addr)); err != nil { + return + } + + next := rl.List + done := 0 + var pendingLockAddr usermem.Addr + if rl.ListOpPending != 0 { + pendingLockAddr = usermem.Addr(rl.ListOpPending + rl.FutexOffset) + } + + // Wake up normal elements. + for usermem.Addr(next) != addr { + // We traverse to the next element of the list before we + // actually wake anything. This prevents the race where waking + // this futex causes a modification of the list. + thisLockAddr := usermem.Addr(next + rl.FutexOffset) + + // Try to decode the next element in the list before waking the + // current futex. But don't check the error until after we've + // woken the current futex. Linux does it in this order too + _, nextErr := t.CopyIn(usermem.Addr(next), &next) + + // Wakeup the current futex if it's not pending. + if thisLockAddr != pendingLockAddr { + t.wakeRobustListOne(thisLockAddr) + } + + // If there was an error copying the next futex, we must bail. + if nextErr != nil { + break + } + + // This is a user structure, so it could be a massive list, or + // even contain a loop if they are trying to mess with us. We + // cap traversal to prevent that. + done++ + if done >= linux.ROBUST_LIST_LIMIT { + break + } + } + + // Is there a pending entry to wake? + if pendingLockAddr != 0 { + t.wakeRobustListOne(pendingLockAddr) + } +} + +// wakeRobustListOne wakes a single futex from the robust list. +func (t *Task) wakeRobustListOne(addr usermem.Addr) { + // Bit 0 in address signals PI futex. + pi := addr&1 == 1 + addr = addr &^ 1 + + // Load the futex. + f, err := t.LoadUint32(addr) + if err != nil { + // Can't read this single value? Ignore the problem. + // We can wake the other futexes in the list. + return + } + + tid := uint32(t.ThreadID()) + for { + // Is this held by someone else? + if f&linux.FUTEX_TID_MASK != tid { + return + } + + // This thread is dying and it's holding this futex. We need to + // set the owner died bit and wake up any waiters. + newF := (f & linux.FUTEX_WAITERS) | linux.FUTEX_OWNER_DIED + if curF, err := t.CompareAndSwapUint32(addr, f, newF); err != nil { + return + } else if curF != f { + // Futex changed out from under us. Try again... + f = curF + continue + } + + // Wake waiters if there are any. + if f&linux.FUTEX_WAITERS != 0 { + private := f&linux.FUTEX_PRIVATE_FLAG != 0 + if pi { + t.Futex().UnlockPI(t, addr, tid, private) + return + } + t.Futex().Wake(t, addr, private, linux.FUTEX_BITSET_MATCH_ANY, 1) + } + + // Done. + return + } +} diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index d654dd997..7d4f44caf 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -167,7 +167,22 @@ func (app *runApp) execute(t *Task) taskRunState { return (*runInterrupt)(nil) } - // We're about to switch to the application again. If there's still a + // Execute any task work callbacks before returning to user space. + if atomic.LoadInt32(&t.taskWorkCount) > 0 { + t.taskWorkMu.Lock() + queue := t.taskWork + t.taskWork = nil + atomic.StoreInt32(&t.taskWorkCount, 0) + t.taskWorkMu.Unlock() + + // Do not hold taskWorkMu while executing task work, which may register + // more work. + for _, work := range queue { + work.TaskWork(t) + } + } + + // We're about to switch to the application again. If there's still an // unhandled SyscallRestartErrno that wasn't translated to an EINTR, // restart the syscall that was interrupted. If there's a saved signal // mask, restore it. (Note that restoring the saved signal mask may unblock diff --git a/pkg/sentry/kernel/task_work.go b/pkg/sentry/kernel/task_work.go new file mode 100644 index 000000000..dda5a433a --- /dev/null +++ b/pkg/sentry/kernel/task_work.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import "sync/atomic" + +// TaskWorker is a deferred task. +// +// This must be savable. +type TaskWorker interface { + // TaskWork will be executed prior to returning to user space. Note that + // TaskWork may call RegisterWork again, but this will not be executed until + // the next return to user space, unlike in Linux. This effectively allows + // registration of indefinite user return hooks, but not by default. + TaskWork(t *Task) +} + +// RegisterWork can be used to register additional task work that will be +// performed prior to returning to user space. See TaskWorker.TaskWork for +// semantics regarding registration. +func (t *Task) RegisterWork(work TaskWorker) { + t.taskWorkMu.Lock() + defer t.taskWorkMu.Unlock() + atomic.AddInt32(&t.taskWorkCount, 1) + t.taskWork = append(t.taskWork, work) +} diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 52849f5b3..4dfd2c990 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -366,7 +366,8 @@ func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error { // terminal is stolen, and all processes that had it as controlling // terminal lose it." - tty_ioctl(4) if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session { - if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 { + // Stealing requires CAP_SYS_ADMIN in the root user namespace. + if creds := auth.CredentialsFromContext(tg.leader); !creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, creds.UserNamespace.Root()) || arg != 1 { return syserror.EPERM } // Steal the TTY away. Unlike TIOCNOTTY, don't send signals. diff --git a/pkg/sentry/kernel/time/BUILD b/pkg/sentry/kernel/time/BUILD index 7ba7dc50c..2817aa3ba 100644 --- a/pkg/sentry/kernel/time/BUILD +++ b/pkg/sentry/kernel/time/BUILD @@ -6,6 +6,7 @@ go_library( name = "time", srcs = [ "context.go", + "tcpip.go", "time.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/kernel/time/tcpip.go b/pkg/sentry/kernel/time/tcpip.go new file mode 100644 index 000000000..c4474c0cf --- /dev/null +++ b/pkg/sentry/kernel/time/tcpip.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package time + +import ( + "sync" + "time" +) + +// TcpipAfterFunc waits for duration to elapse according to clock then runs fn. +// The timer is started immediately and will fire exactly once. +func TcpipAfterFunc(clock Clock, duration time.Duration, fn func()) *TcpipTimer { + timer := &TcpipTimer{ + clock: clock, + } + timer.notifier = functionNotifier{ + fn: func() { + // tcpip.Timer.Stop() explicitly states that the function is called in a + // separate goroutine that Stop() does not synchronize with. + // Timer.Destroy() synchronizes with calls to TimerListener.Notify(). + // This is semantically meaningful because, in the former case, it's + // legal to call tcpip.Timer.Stop() while holding locks that may also be + // taken by the function, but this isn't so in the latter case. Most + // immediately, Timer calls TimerListener.Notify() while holding + // Timer.mu. A deadlock occurs without spawning a goroutine: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, deadlock! + // + // Spawning a goroutine avoids the deadlock: + // T1: (Timer expires) + // => Timer.Tick() <- Timer.mu.Lock() called + // => TimerListener.Notify() <- Launches T2 + // T2: + // => Timer.Stop() + // => Timer.Destroy() <- Timer.mu.Lock() called, blocks + // T1: + // => (returns) <- Timer.mu.Unlock() called + // T2: + // => (continues) <- No deadlock! + go func() { + timer.Stop() + fn() + }() + }, + } + timer.Reset(duration) + return timer +} + +// TcpipTimer is a resettable timer with variable duration expirations. +// Implements tcpip.Timer, which does not define a Destroy method; instead, all +// resources are released after timer expiration and calls to Timer.Stop. +// +// Must be created by AfterFunc. +type TcpipTimer struct { + // clock is the time source. clock is immutable. + clock Clock + + // notifier is called when the Timer expires. notifier is immutable. + notifier functionNotifier + + // mu protects t. + mu sync.Mutex + + // t stores the latest running Timer. This is replaced whenever Reset is + // called since Timer cannot be restarted once it has been Destroyed by Stop. + // + // This field is nil iff Stop has been called. + t *Timer +} + +// Stop implements tcpip.Timer.Stop. +func (r *TcpipTimer) Stop() bool { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + return false + } + _, lastSetting := r.t.Swap(Setting{}) + r.t.Destroy() + r.t = nil + return lastSetting.Enabled +} + +// Reset implements tcpip.Timer.Reset. +func (r *TcpipTimer) Reset(d time.Duration) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.t == nil { + r.t = NewTimer(r.clock, &r.notifier) + } + + r.t.Swap(Setting{ + Enabled: true, + Period: 0, + Next: r.clock.Now().Add(d), + }) +} + +// functionNotifier is a TimerListener that runs a function. +// +// functionNotifier cannot be saved or loaded. +type functionNotifier struct { + fn func() +} + +// Notify implements ktime.TimerListener.Notify. +func (f *functionNotifier) Notify(uint64, Setting) (Setting, bool) { + f.fn() + return Setting{}, false +} + +// Destroy implements ktime.TimerListener.Destroy. +func (f *functionNotifier) Destroy() {} diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go index 0adf25691..7c4fefb16 100644 --- a/pkg/sentry/kernel/timekeeper.go +++ b/pkg/sentry/kernel/timekeeper.go @@ -21,8 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/log" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" sentrytime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" ) @@ -90,7 +90,7 @@ type Timekeeper struct { // NewTimekeeper does not take ownership of paramPage. // // SetClocks must be called on the returned Timekeeper before it is usable. -func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) { +func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) { return &Timekeeper{ params: NewVDSOParamPage(mfp, paramPage), }, nil @@ -210,9 +210,6 @@ func (t *Timekeeper) startUpdater() { p.realtimeBaseRef = int64(realtimeParams.BaseRef) p.realtimeFrequency = realtimeParams.Frequency } - - log.Debugf("Updating VDSO parameters: %+v", p) - return p }); err != nil { log.Warningf("Unable to update VDSO parameter page: %v", err) diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index f1b3c212c..290c32466 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -19,8 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/usermem" ) @@ -58,7 +58,7 @@ type vdsoParams struct { type VDSOParamPage struct { // The parameter page is fr, allocated from mfp.MemoryFile(). mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange // seq is the current sequence count written to the page. // @@ -81,7 +81,7 @@ type VDSOParamPage struct { // * VDSOParamPage must be the only writer to fr. // // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. -func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage { +func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { return &VDSOParamPage{mfp: mfp, fr: fr} } diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index c6aa65f28..34bdb0b69 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -30,9 +30,6 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/fs", - "//pkg/sentry/fs/anon", - "//pkg/sentry/fs/fsutil", "//pkg/sentry/fsbridge", "//pkg/sentry/kernel/auth", "//pkg/sentry/limits", @@ -45,6 +42,5 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/usermem", - "//pkg/waiter", ], ) diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 616fafa2c..ddeaff3db 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -90,14 +90,23 @@ type elfInfo struct { sharedObject bool } +// fullReader interface extracts the ReadFull method from fsbridge.File so that +// client code does not need to define an entire fsbridge.File when only read +// functionality is needed. +// +// TODO(gvisor.dev/issue/1035): Once VFS2 ships, rewrite this to wrap +// vfs.FileDescription's PRead/Read instead. +type fullReader interface { + // ReadFull is the same as fsbridge.File.ReadFull. + ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) +} + // parseHeader parse the ELF header, verifying that this is a supported ELF // file and returning the ELF program headers. // // This is similar to elf.NewFile, except that it is more strict about what it // accepts from the ELF, and it doesn't parse unnecessary parts of the file. -// -// ctx may be nil if f does not need it. -func parseHeader(ctx context.Context, f fsbridge.File) (elfInfo, error) { +func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { // Check ident first; it will tell us the endianness of the rest of the // structs. var ident [elf.EI_NIDENT]byte diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 88449fe95..986c7fb4d 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/mm" @@ -80,22 +79,6 @@ type LoadArgs struct { Features *cpuid.FeatureSet } -// readFull behaves like io.ReadFull for an *fs.File. -func readFull(ctx context.Context, f *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var total int64 - for dst.NumBytes() > 0 { - n, err := f.Preadv(ctx, dst, offset+total) - total += n - if err == io.EOF && total != 0 { - return total, io.ErrUnexpectedEOF - } else if err != nil { - return total, err - } - dst = dst.DropFirst64(n) - } - return total, nil -} - // openPath opens args.Filename and checks that it is valid for loading. // // openPath returns an *fs.Dirent and *fs.File for args.Filename, which is not @@ -238,14 +221,14 @@ func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *V // Load the executable itself. loaded, ac, file, newArgv, err := loadExecutable(ctx, args) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("failed to load %s: %v", args.Filename, err), syserr.FromError(err).ToLinux()) } defer file.DecRef() // Load the VDSO. vdsoAddr, err := loadVDSO(ctx, args.MemoryManager, vdso, loaded) if err != nil { - return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("Error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) + return 0, nil, "", syserr.NewDynamic(fmt.Sprintf("error loading VDSO: %v", err), syserr.FromError(err).ToLinux()) } // Setup the heap. brk starts at the next page after the end of the diff --git a/pkg/sentry/loader/vdso.go b/pkg/sentry/loader/vdso.go index 165869028..05a294fe6 100644 --- a/pkg/sentry/loader/vdso.go +++ b/pkg/sentry/loader/vdso.go @@ -26,10 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/anon" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -37,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/pkg/waiter" ) const vdsoPrelink = 0xffffffffff700000 @@ -55,52 +50,11 @@ func (f *fileContext) Value(key interface{}) interface{} { } } -// byteReader implements fs.FileOperations for reading from a []byte source. -type byteReader struct { - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoIoctl `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileNoopRelease `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - waiter.AlwaysReady `state:"nosave"` - +type byteFullReader struct { data []byte } -var _ fs.FileOperations = (*byteReader)(nil) - -// newByteReaderFile creates a fake file to read data from. -// -// TODO(gvisor.dev/issue/2921): Convert to VFS2. -func newByteReaderFile(ctx context.Context, data []byte) *fs.File { - // Create a fake inode. - inode := fs.NewInode( - ctx, - &fsutil.SimpleFileInode{}, - fs.NewPseudoMountSource(ctx), - fs.StableAttr{ - Type: fs.Anonymous, - DeviceID: anon.PseudoDevice.DeviceID(), - InodeID: anon.PseudoDevice.NextIno(), - BlockSize: usermem.PageSize, - }) - - // Use the fake inode to create a fake dirent. - dirent := fs.NewTransientDirent(inode) - defer dirent.DecRef() - - // Use the fake dirent to make a fake file. - flags := fs.FileFlags{Read: true, Pread: true} - return fs.NewFile(&fileContext{Context: context.Background()}, dirent, flags, &byteReader{ - data: data, - }) -} - -func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { +func (b *byteFullReader) ReadFull(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { if offset < 0 { return 0, syserror.EINVAL } @@ -111,10 +65,6 @@ func (b *byteReader) Read(ctx context.Context, file *fs.File, dst usermem.IOSequ return int64(n), err } -func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) { - panic("Write not supported") -} - // validateVDSO checks that the VDSO can be loaded by loadVDSO. // // VDSOs are special (see below). Since we are going to map the VDSO directly @@ -130,7 +80,7 @@ func (b *byteReader) Write(ctx context.Context, file *fs.File, src usermem.IOSeq // * PT_LOAD segments don't extend beyond the end of the file. // // ctx may be nil if f does not need it. -func validateVDSO(ctx context.Context, f fsbridge.File, size uint64) (elfInfo, error) { +func validateVDSO(ctx context.Context, f fullReader, size uint64) (elfInfo, error) { info, err := parseHeader(ctx, f) if err != nil { log.Infof("Unable to parse VDSO header: %v", err) @@ -248,13 +198,12 @@ func getSymbolValueFromVDSO(symbol string) (uint64, error) { // PrepareVDSO validates the system VDSO and returns a VDSO, containing the // param page for updating by the kernel. -func PrepareVDSO(ctx context.Context, mfp pgalloc.MemoryFileProvider) (*VDSO, error) { - vdsoFile := fsbridge.NewFSFile(newByteReaderFile(ctx, vdsoBin)) +func PrepareVDSO(mfp pgalloc.MemoryFileProvider) (*VDSO, error) { + vdsoFile := &byteFullReader{data: vdsoBin} // First make sure the VDSO is valid. vdsoFile does not use ctx, so a // nil context can be passed. info, err := validateVDSO(nil, vdsoFile, uint64(len(vdsoBin))) - vdsoFile.DecRef() if err != nil { return nil, err } diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD index a98b66de1..2c95669cd 100644 --- a/pkg/sentry/memmap/BUILD +++ b/pkg/sentry/memmap/BUILD @@ -28,9 +28,21 @@ go_template_instance( }, ) +go_template_instance( + name = "file_range", + out = "file_range.go", + package = "memmap", + prefix = "File", + template = "//pkg/segment:generic_range", + types = { + "T": "uint64", + }, +) + go_library( name = "memmap", srcs = [ + "file_range.go", "mappable_range.go", "mapping_set.go", "mapping_set_impl.go", @@ -40,7 +52,7 @@ go_library( deps = [ "//pkg/context", "//pkg/log", - "//pkg/sentry/platform", + "//pkg/safemem", "//pkg/syserror", "//pkg/usermem", ], diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index c6db9fc8f..c188f6c29 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -19,12 +19,12 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/usermem" ) // Mappable represents a memory-mappable object, a mutable mapping from uint64 -// offsets to (platform.File, uint64 File offset) pairs. +// offsets to (File, uint64 File offset) pairs. // // See mm/mm.go for Mappable's place in the lock order. // @@ -74,7 +74,7 @@ type Mappable interface { // Translations are valid until invalidated by a callback to // MappingSpace.Invalidate or until the caller removes its mapping of the // translated range. Mappable implementations must ensure that at least one - // reference is held on all pages in a platform.File that may be the result + // reference is held on all pages in a File that may be the result // of a valid Translation. // // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). @@ -100,7 +100,7 @@ type Translation struct { Source MappableRange // File is the mapped file. - File platform.File + File File // Offset is the offset into File at which this Translation begins. Offset uint64 @@ -110,9 +110,9 @@ type Translation struct { Perms usermem.AccessType } -// FileRange returns the platform.FileRange represented by t. -func (t Translation) FileRange() platform.FileRange { - return platform.FileRange{t.Offset, t.Offset + t.Source.Length()} +// FileRange returns the FileRange represented by t. +func (t Translation) FileRange() FileRange { + return FileRange{t.Offset, t.Offset + t.Source.Length()} } // CheckTranslateResult returns an error if (ts, terr) does not satisfy all @@ -361,3 +361,49 @@ type MMapOpts struct { // TODO(jamieliu): Replace entirely with MappingIdentity? Hint string } + +// File represents a host file that may be mapped into an platform.AddressSpace. +type File interface { + // All pages in a File are reference-counted. + + // IncRef increments the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. (The File + // interface does not provide a way to acquire an initial reference; + // implementors may define mechanisms for doing so.) + IncRef(fr FileRange) + + // DecRef decrements the reference count on all pages in fr. + // + // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > + // 0. At least one reference must be held on all pages in fr. + DecRef(fr FileRange) + + // MapInternal returns a mapping of the given file offsets in the invoking + // process' address space for reading and writing. + // + // Note that fr.Start and fr.End need not be page-aligned. + // + // Preconditions: fr.Length() > 0. At least one reference must be held on + // all pages in fr. + // + // Postconditions: The returned mapping is valid as long as at least one + // reference is held on the mapped pages. + MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) + + // FD returns the file descriptor represented by the File. + // + // The only permitted operation on the returned file descriptor is to map + // pages from it consistent with the requirements of AddressSpace.MapFile. + FD() int +} + +// FileRange represents a range of uint64 offsets into a File. +// +// type FileRange <generated using go_generics> + +// String implements fmt.Stringer.String. +func (fr FileRange) String() string { + return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) +} diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index a036ce53c..f9d0837a1 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -7,14 +7,14 @@ go_template_instance( name = "file_refcount_set", out = "file_refcount_set.go", imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "mm", prefix = "fileRefcount", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "int32", "Functions": "fileRefcountSetFunctions", }, diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 379148903..1999ec706 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -243,7 +242,7 @@ type aioMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange } var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp()) diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 6db7c3d40..3e85964e4 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -25,7 +25,7 @@ // Locks taken by memmap.Mappable.Translate // mm.privateRefs.mu // platform.AddressSpace locks -// platform.File locks +// memmap.File locks // mm.aioManager.mu // mm.AIOContext.mu // @@ -396,7 +396,7 @@ type pma struct { // file is the file mapped by this pma. Only pmas for which file == // MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to // the corresponding file range while they exist. - file platform.File `state:"nosave"` + file memmap.File `state:"nosave"` // off is the offset into file at which this pma begins. // @@ -436,7 +436,7 @@ type pma struct { private bool // If internalMappings is not empty, it is the cached return value of - // file.MapInternal for the platform.FileRange mapped by this pma. + // file.MapInternal for the memmap.FileRange mapped by this pma. internalMappings safemem.BlockSeq `state:"nosave"` } @@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 { func (fileRefcountSetFunctions) ClearValue(_ *int32) { } -func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) { +func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) { return rc1, rc1 == rc2 } -func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) { +func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) { return rc, rc } diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 62e4c20af..930ec895f 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -21,7 +21,6 @@ import ( "gvisor.dev/gvisor/pkg/safecopy" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat } } -// Pin returns the platform.File ranges currently mapped by addresses in ar in +// Pin returns the memmap.File ranges currently mapped by addresses in ar in // mm, acquiring a reference on the returned ranges which the caller must // release by calling Unpin. If not all addresses are mapped, Pin returns a // non-nil error. Note that Pin may return both a non-empty slice of @@ -674,15 +673,15 @@ type PinnedRange struct { Source usermem.AddrRange // File is the mapped file. - File platform.File + File memmap.File // Offset is the offset into File at which this PinnedRange begins. Offset uint64 } -// FileRange returns the platform.File offsets mapped by pr. -func (pr PinnedRange) FileRange() platform.FileRange { - return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} +// FileRange returns the memmap.File offsets mapped by pr. +func (pr PinnedRange) FileRange() memmap.FileRange { + return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())} } // Unpin releases the reference held by prs. @@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf } // incPrivateRef acquires a reference on private pages in fr. -func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { +func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) { mm.privateRefs.mu.Lock() defer mm.privateRefs.mu.Unlock() refSet := &mm.privateRefs.refs @@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) { } // decPrivateRef releases a reference on private pages in fr. -func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) { - var freed []platform.FileRange +func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) { + var freed []memmap.FileRange mm.privateRefs.mu.Lock() refSet := &mm.privateRefs.refs @@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa // Discard internal mappings instead of trying to merge them, since merging // them requires an allocation and getting them again from the - // platform.File might not. + // memmap.File might not. pma1.internalMappings = safemem.BlockSeq{} return pma1, true } @@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error { return nil } -func (pseg pmaIterator) fileRange() platform.FileRange { +func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } // Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. -func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { +func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { panic("terminal pma iterator") @@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange { pma := pseg.ValuePtr() pstart := pseg.Start() - return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} + return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)} } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 9ad52082d..0e142fb11 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" - "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -35,7 +34,7 @@ type SpecialMappable struct { refs.AtomicRefCount mfp pgalloc.MemoryFileProvider - fr platform.FileRange + fr memmap.FileRange name string } @@ -44,7 +43,7 @@ type SpecialMappable struct { // SpecialMappable will use the given name in /proc/[pid]/maps. // // Preconditions: fr.Length() != 0. -func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable { +func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} m.EnableLeakCheck("mm.SpecialMappable") return &m @@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider { // FileRange returns the offsets into MemoryFileProvider().MemoryFile() that // store the SpecialMappable's contents. -func (m *SpecialMappable) FileRange() platform.FileRange { +func (m *SpecialMappable) FileRange() memmap.FileRange { return m.fr } diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index a9836ba71..7a3311a70 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -36,14 +36,14 @@ go_template_instance( "trackGaps": "1", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "usage", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "usageInfo", "Functions": "usageSetFunctions", }, @@ -56,14 +56,14 @@ go_template_instance( "minDegree": "10", }, imports = { - "platform": "gvisor.dev/gvisor/pkg/sentry/platform", + "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap", }, package = "pgalloc", prefix = "reclaim", template = "//pkg/segment:generic_set", types = { "Key": "uint64", - "Range": "platform.FileRange", + "Range": "memmap.FileRange", "Value": "reclaimSetValue", "Functions": "reclaimSetFunctions", }, @@ -89,9 +89,10 @@ go_library( "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/hostmm", - "//pkg/sentry/platform", + "//pkg/sentry/memmap", "//pkg/sentry/usage", "//pkg/state", + "//pkg/state/wire", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index afab97c0a..3243d7214 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -33,14 +33,14 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/hostmm" - "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) -// MemoryFile is a platform.File whose pages may be allocated to arbitrary +// MemoryFile is a memmap.File whose pages may be allocated to arbitrary // users. type MemoryFile struct { // opts holds options passed to NewMemoryFile. opts is immutable. @@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() { // to Allocate. // // Preconditions: length must be page-aligned and non-zero. -func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) { +func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) { if length == 0 || length%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid allocation length: %#x", length)) } @@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Find a range in the underlying file. fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment) if !ok { - return platform.FileRange{}, syserror.ENOMEM + return memmap.FileRange{}, syserror.ENOMEM } // Expand the file if needed. @@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // Round the new file size up to be chunk-aligned. newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask if err := f.file.Truncate(newFileSize); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } f.fileSize = newFileSize f.mappingsMu.Lock() @@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi bs[i] = 0 } }); err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } } if !f.usage.Add(fr, usageInfo{ @@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi // space for mappings to be allocated downwards. // // Precondition: alignment must be a power of 2. -func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) { +func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) { alignmentMask := alignment - 1 // Search for space in existing gaps, starting at the current end of the @@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 break } if start := unalignedStart &^ alignmentMask; start >= gap.Start() { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } gap = gap.PrevLargeEnoughGap(length) @@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 min = (min + alignmentMask) &^ alignmentMask if min+length < min { // Overflow: allocation would exceed the range of uint64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } // Determine the minimum file size required to fit this allocation at its end. @@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 if newFileSize <= fileSize { if fileSize != 0 { // Overflow: allocation would exceed the range of int64. - return platform.FileRange{}, false + return memmap.FileRange{}, false } newFileSize = chunkSize } @@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 continue } if start := unalignedStart &^ alignmentMask; start >= min { - return platform.FileRange{start, start + length}, true + return memmap.FileRange{start, start + length}, true } } } @@ -508,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // by r.ReadToBlocks(), it returns that error. // // Preconditions: length > 0. length must be page-aligned. -func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) { +func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { - return platform.FileRange{}, err + return memmap.FileRange{}, err } dsts, err := f.MapInternal(fr, usermem.Write) if err != nil { f.DecRef(fr) - return platform.FileRange{}, err + return memmap.FileRange{}, err } n, err := safemem.ReadFullToBlocks(r, dsts) un := uint64(usermem.Addr(n).RoundDown()) if un < length { // Free unused memory and update fr to contain only the memory that is // still allocated. - f.DecRef(platform.FileRange{fr.Start + un, fr.End}) + f.DecRef(memmap.FileRange{fr.Start + un, fr.End}) fr.End = fr.Start + un } return fr, err @@ -540,7 +540,7 @@ const ( // will read zeroes. // // Preconditions: fr.Length() > 0. -func (f *MemoryFile) Decommit(fr platform.FileRange) error { +func (f *MemoryFile) Decommit(fr memmap.FileRange) error { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -560,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error { return nil } -func (f *MemoryFile) markDecommitted(fr platform.FileRange) { +func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() // Since we're changing the knownCommitted attribute, we need to merge @@ -581,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) { f.usage.MergeRange(fr) } -// IncRef implements platform.File.IncRef. -func (f *MemoryFile) IncRef(fr platform.FileRange) { +// IncRef implements memmap.File.IncRef. +func (f *MemoryFile) IncRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -600,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) { f.usage.MergeAdjacent(fr) } -// DecRef implements platform.File.DecRef. -func (f *MemoryFile) DecRef(fr platform.FileRange) { +// DecRef implements memmap.File.DecRef. +func (f *MemoryFile) DecRef(fr memmap.FileRange) { if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -637,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) { } } -// MapInternal implements platform.File.MapInternal. -func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { +// MapInternal implements memmap.File.MapInternal. +func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { if !fr.WellFormed() || fr.Length() == 0 { panic(fmt.Sprintf("invalid range: %v", fr)) } @@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) ( // forEachMappingSlice invokes fn on a sequence of byte slices that // collectively map all bytes in fr. -func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error { +func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error { mappings := f.mappings.Load().([]uintptr) for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize { chunk := int(chunkStart >> chunkShift) @@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( continue case !populated && populatedRun: // Finish the run by changing this segment. - runRange := platform.FileRange{ + runRange := memmap.FileRange{ Start: r.Start + uint64(populatedRunStart*usermem.PageSize), End: r.Start + uint64(i*usermem.PageSize), } @@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File { return f.file } -// FD implements platform.File.FD. +// FD implements memmap.File.FD. func (f *MemoryFile) FD() int { return int(f.file.Fd()) } @@ -1090,13 +1090,13 @@ func (f *MemoryFile) runReclaim() { // // Note that there returned range will be removed from tracking. It // must be reclaimed (removed from f.usage) at this point. -func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { +func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) { f.mu.Lock() defer f.mu.Unlock() for { for { if f.destroyed { - return platform.FileRange{}, false + return memmap.FileRange{}, false } if f.reclaimable { break @@ -1120,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) { } } -func (f *MemoryFile) markReclaimed(fr platform.FileRange) { +func (f *MemoryFile) markReclaimed(fr memmap.FileRange) { f.mu.Lock() defer f.mu.Unlock() seg := f.usage.FindSegment(fr.Start) @@ -1222,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 { func (usageSetFunctions) ClearValue(val *usageInfo) { } -func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) { +func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) { return val1, val1 == val2 } -func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { +func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) { return val, val } @@ -1270,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 { func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) { } -func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { +func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) { return reclaimSetValue{}, true } -func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { +func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) { return reclaimSetValue{}, reclaimSetValue{} } diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index f8385c146..78317fa35 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -26,11 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" "gvisor.dev/gvisor/pkg/usermem" ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w wire.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -79,10 +80,10 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // Save metadata. - if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { + if _, err := state.Save(ctx, w, &f.fileSize); err != nil { return err } - if err := state.Save(ctx, w, &f.usage, nil); err != nil { + if _, err := state.Save(ctx, w, &f.usage); err != nil { return err } @@ -115,9 +116,9 @@ func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r wire.Reader) error { // Load metadata. - if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { + if _, err := state.Load(ctx, r, &f.fileSize); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -125,7 +126,7 @@ func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(ctx, r, &f.usage, nil); err != nil { + if _, err := state.Load(ctx, r, &f.usage); err != nil { return err } diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD index 453241eca..209b28053 100644 --- a/pkg/sentry/platform/BUILD +++ b/pkg/sentry/platform/BUILD @@ -1,39 +1,21 @@ load("//tools:defs.bzl", "go_library") -load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) -go_template_instance( - name = "file_range", - out = "file_range.go", - package = "platform", - prefix = "File", - template = "//pkg/segment:generic_range", - types = { - "T": "uint64", - }, -) - go_library( name = "platform", srcs = [ "context.go", - "file_range.go", "mmap_min_addr.go", "platform.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/atomicbitops", "//pkg/context", - "//pkg/log", - "//pkg/safecopy", - "//pkg/safemem", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/usage", - "//pkg/syserror", + "//pkg/sentry/memmap", "//pkg/usermem", ], ) diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 4792454c4..b5d27a72a 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -47,6 +47,7 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sentry/platform/ring0", @@ -60,6 +61,7 @@ go_library( go_test( name = "kvm_test", srcs = [ + "kvm_amd64_test.go", "kvm_test.go", "virtual_map_test.go", ], diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index faf1d5e1c..98a3e539d 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/atomicbitops" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/sync" @@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem. } // MapFile implements platform.AddressSpace.MapFile. -func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { as.mu.Lock() defer as.mu.Unlock() diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 0d1e83e6c..03a98512e 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -17,6 +17,7 @@ package kvm import ( + "syscall" "unsafe" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -60,3 +61,27 @@ func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { func getHypercallID(addr uintptr) int { return _KVM_HYPERCALL_MAX } + +// bluepillStopGuest is reponsible for injecting interrupt. +// +//go:nosplit +func bluepillStopGuest(c *vCPU) { + // Interrupt: we must have requested an interrupt + // window; set the interrupt line. + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_INTERRUPT, + uintptr(unsafe.Pointer(&bounce))); errno != 0 { + throw("interrupt injection failed") + } + // Clear previous injection request. + c.runData.requestInterruptWindow = 0 +} + +// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. +// +//go:nosplit +func bluepillReadyStopGuest(c *vCPU) bool { + return c.runData.readyForInterruptInjection != 0 +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 83643c602..dba563160 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -26,6 +26,17 @@ import ( var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = syscall.SIGILL + + // vcpuSErr is the event of system error. + vcpuSErr = kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 0, + pad: [6]uint8{0, 0, 0, 0, 0, 0}, + sErrEsr: 1, + }, + rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } ) // bluepillArchEnter is called during bluepillEnter. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index abd36f973..8b64f3a1e 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -17,6 +17,7 @@ package kvm import ( + "syscall" "unsafe" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -74,3 +75,23 @@ func getHypercallID(addr uintptr) int { return int(((addr) - arm64HypercallMMIOBase) >> 3) } } + +// bluepillStopGuest is reponsible for injecting sError. +// +//go:nosplit +func bluepillStopGuest(c *vCPU) { + if _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 { + throw("sErr injection failed") + } +} + +// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection. +// +//go:nosplit +func bluepillReadyStopGuest(c *vCPU) bool { + return true +} diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index a5b9be36d..bf357de1a 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -133,12 +133,12 @@ func bluepillHandler(context unsafe.Pointer) { // PIC, we can't inject an interrupt while they are // masked. We need to request a window if it's not // ready. - if c.runData.readyForInterruptInjection == 0 { - c.runData.requestInterruptWindow = 1 - continue // Rerun vCPU. - } else { + if bluepillReadyStopGuest(c) { // Force injection below; the vCPU is ready. c.runData.exitReason = _KVM_EXIT_IRQ_WINDOW_OPEN + } else { + c.runData.requestInterruptWindow = 1 + continue // Rerun vCPU. } case syscall.EFAULT: // If a fault is not serviceable due to the host @@ -217,17 +217,7 @@ func bluepillHandler(context unsafe.Pointer) { } } case _KVM_EXIT_IRQ_WINDOW_OPEN: - // Interrupt: we must have requested an interrupt - // window; set the interrupt line. - if _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_INTERRUPT, - uintptr(unsafe.Pointer(&bounce))); errno != 0 { - throw("interrupt injection failed") - } - // Clear previous injection request. - c.runData.requestInterruptWindow = 0 + bluepillStopGuest(c) case _KVM_EXIT_SHUTDOWN: c.die(bluepillArchContext(context), "shutdown") return diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go new file mode 100644 index 000000000..c0b4fd374 --- /dev/null +++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go @@ -0,0 +1,51 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package kvm + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/sentry/platform/kvm/testutil" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" +) + +func TestSegments(t *testing.T) { + applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { + testutil.SetTestSegments(regs) + for { + var si arch.SignalInfo + if _, err := c.SwitchToUser(ring0.SwitchOpts{ + Registers: regs, + FloatingPointState: dummyFPState, + PageTables: pt, + FullRestore: true, + }, &si); err == platform.ErrContextInterrupt { + continue // Retry. + } else if err != nil { + t.Errorf("application segment check with full restore got unexpected error: %v", err) + } + if err := testutil.CheckTestSegments(regs); err != nil { + t.Errorf("application segment check with full restore failed: %v", err) + } + break // Done. + } + return false + }) +} diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 3134a076b..0b06a923a 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -46,6 +46,18 @@ type userRegs struct { fpRegs userFpsimdState } +type exception struct { + sErrPending uint8 + sErrHasEsr uint8 + pad [6]uint8 + sErrEsr uint64 +} + +type kvmVcpuEvents struct { + exception + rsvd [12]uint32 +} + // updateGlobalOnce does global initialization. It has to be called only once. func updateGlobalOnce(fd int) error { physicalInit() diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go index 6531bae1d..48ccf8474 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go @@ -22,7 +22,8 @@ import ( ) var ( - runDataSize int + runDataSize int + hasGuestPCID bool ) func updateSystemValues(fd int) error { @@ -33,6 +34,7 @@ func updateSystemValues(fd int) error { } // Save the data. runDataSize = int(sz) + hasGuestPCID = true // Success. return nil diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 6a7676468..3bf918446 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -35,6 +35,8 @@ const ( _KVM_GET_SUPPORTED_CPUID = 0xc008ae05 _KVM_SET_CPUID2 = 0x4008ae90 _KVM_SET_SIGNAL_MASK = 0x4004ae8b + _KVM_GET_VCPU_EVENTS = 0x8040ae9f + _KVM_SET_VCPU_EVENTS = 0x4040aea0 ) // KVM exit reasons. @@ -54,8 +56,10 @@ const ( // KVM capability options. const ( - _KVM_CAP_MAX_VCPUS = 0x42 - _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 + _KVM_CAP_MAX_VCPUS = 0x42 + _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 + _KVM_CAP_VCPU_EVENTS = 0x29 + _KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e ) // KVM limits. diff --git a/pkg/sentry/platform/kvm/kvm_test.go b/pkg/sentry/platform/kvm/kvm_test.go index 6c8f4fa28..45b3180f1 100644 --- a/pkg/sentry/platform/kvm/kvm_test.go +++ b/pkg/sentry/platform/kvm/kvm_test.go @@ -262,30 +262,6 @@ func TestRegistersFault(t *testing.T) { }) } -func TestSegments(t *testing.T) { - applicationTest(t, true, testutil.TwiddleSegments, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { - testutil.SetTestSegments(regs) - for { - var si arch.SignalInfo - if _, err := c.SwitchToUser(ring0.SwitchOpts{ - Registers: regs, - FloatingPointState: dummyFPState, - PageTables: pt, - FullRestore: true, - }, &si); err == platform.ErrContextInterrupt { - continue // Retry. - } else if err != nil { - t.Errorf("application segment check with full restore got unexpected error: %v", err) - } - if err := testutil.CheckTestSegments(regs); err != nil { - t.Errorf("application segment check with full restore failed: %v", err) - } - break // Done. - } - return false - }) -} - func TestBounce(t *testing.T) { applicationTest(t, true, testutil.SpinLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool { go func() { diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 48c834499..ff8c068c0 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" + "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" "gvisor.dev/gvisor/pkg/usermem" ) @@ -78,19 +79,6 @@ func (c *vCPU) initArchState() error { return err } - // sctlr_el1 - regGet.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.getOneRegister(®Get); err != nil { - return err - } - - dataGet |= (_SCTLR_M | _SCTLR_C | _SCTLR_I) - data = dataGet - reg.id = _KVM_ARM64_REGS_SCTLR_EL1 - if err := c.setOneRegister(®); err != nil { - return err - } - // tcr_el1 data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS reg.id = _KVM_ARM64_REGS_TCR_EL1 @@ -169,6 +157,14 @@ func (c *vCPU) initArchState() error { return err } + // Initialize the PCID database. + if hasGuestPCID { + // Note that NewPCIDs may return a nil table here, in which + // case we simply don't use PCID support (see below). In + // practice, this should not happen, however. + c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs) + } + c.floatingPointState = arch.NewFloatingPointData() return nil } @@ -247,6 +243,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) } + // Assign PCIDs. + if c.PCIDs != nil { + var requireFlushPCID bool // Force a flush? + switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables) + switchOpts.Flush = switchOpts.Flush || requireFlushPCID + } + var vector ring0.Vector ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0) c.SetTtbr0App(uintptr(ttbr0App)) @@ -273,8 +276,8 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) - case 0xaa: - return usermem.NoAccess, nil + case ring0.Vector(bounce): // ring0.VirtualizationException + return usermem.NoAccess, platform.ErrContextInterrupt default: return usermem.NoAccess, platform.ErrContextSignal } diff --git a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s index 0bebee852..07658144e 100644 --- a/pkg/sentry/platform/kvm/testutil/testutil_arm64.s +++ b/pkg/sentry/platform/kvm/testutil/testutil_arm64.s @@ -104,3 +104,9 @@ TEXT ·TwiddleRegsSyscall(SB),NOSPLIT,$0 TWIDDLE_REGS() SVC RET // never reached + +TEXT ·TwiddleRegsFault(SB),NOSPLIT,$0 + TWIDDLE_REGS() + // Branch to Register branches unconditionally to an address in <Rn>. + JMP (R4) // <=> br x4, must fault + RET // never reached diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index 171513f3f..4b13eec30 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -22,9 +22,9 @@ import ( "os" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/usermem" ) @@ -207,7 +207,7 @@ type AddressSpace interface { // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. // at.Any() == true. At least one reference must be held on all pages in // fr, and must continue to be held as long as pages are mapped. - MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error + MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // @@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string { return fmt.Sprintf("segmentation fault at %#x", f.Addr) } -// File represents a host file that may be mapped into an AddressSpace. -type File interface { - // All pages in a File are reference-counted. - - // IncRef increments the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) - IncRef(fr FileRange) - - // DecRef decrements the reference count on all pages in fr. - // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. - DecRef(fr FileRange) - - // MapInternal returns a mapping of the given file offsets in the invoking - // process' address space for reading and writing. - // - // Note that fr.Start and fr.End need not be page-aligned. - // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. - // - // Postconditions: The returned mapping is valid as long as at least one - // reference is held on the mapped pages. - MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error) - - // FD returns the file descriptor represented by the File. - // - // The only permitted operation on the returned file descriptor is to map - // pages from it consistent with the requirements of AddressSpace.MapFile. - FD() int -} - -// FileRange represents a range of uint64 offsets into a File. -// -// type FileRange <generated using go_generics> - -// String implements fmt.Stringer.String. -func (fr FileRange) String() string { - return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End) -} - // Requirements is used to specify platform specific requirements. type Requirements struct { // RequiresCurrentPIDNS indicates that the sandbox has to be started in the diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index 30402c2df..29fd23cc3 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,6 +30,7 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/hostcpu", + "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", "//pkg/sync", diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 2389423b0..c990f3454 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp } // MapFile implements platform.AddressSpace.MapFile. -func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error { +func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error { var flags int if precommit { flags |= syscall.MAP_POPULATE diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2bc5f3ecd..9fd02d628 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -40,6 +40,20 @@ #define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT) +// sctlr_el1: system control register el1. +#define SCTLR_M 1 << 0 +#define SCTLR_C 1 << 2 +#define SCTLR_I 1 << 12 +#define SCTLR_UCT 1 << 15 + +#define SCTLR_EL1_DEFAULT (SCTLR_M | SCTLR_C | SCTLR_I | SCTLR_UCT) + +// cntkctl_el1: counter-timer kernel control register el1. +#define CNTKCTL_EL0PCTEN 1 << 0 +#define CNTKCTL_EL0VCTEN 1 << 1 + +#define CNTKCTL_EL1_DEFAULT (CNTKCTL_EL0PCTEN | CNTKCTL_EL0VCTEN) + // Saves a register set. // // This is a macro because it may need to executed in contents where a stack is @@ -496,6 +510,14 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 IRQ_DISABLE + + // Init. + MOVD $SCTLR_EL1_DEFAULT, R1 + MSR R1, SCTLR_EL1 + + MOVD $CNTKCTL_EL1_DEFAULT, R1 + MSR R1, CNTKCTL_EL1 + MOVD R8, RSV_REG ORR $0xffff000000000000, RSV_REG, RSV_REG WORD $0xd518d092 //MSR R18, TPIDR_EL1 diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index ccacaea6b..fca3a5478 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -58,7 +58,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(UserFlagsClear) regs.Pstate |= UserFlagsSet + + SetTLS(regs.TPIDR_EL0) + kernelExitToEl0() + + regs.TPIDR_EL0 = GetTLS() + vector = c.vecCode // Perform the switch. diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c40c6d673..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,5 +20,6 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index ff81ea6e6..e76e498de 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a92aed2c9..532a1ea5d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -36,6 +36,8 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -319,7 +321,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } @@ -364,7 +366,8 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. @@ -708,6 +711,6 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) - socket.RegisterProviderVFS2(family, &socketProviderVFS2{}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index ad5f64799..8a1d52ebf 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -71,6 +71,7 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in DenyPWrite: true, UseDentryMetadata: true, }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) return nil, syserr.FromError(err) } return vfsfd, nil @@ -96,7 +97,12 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// PRead implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + +// PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index f7abe77d3..a9f0604ae 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -66,7 +66,7 @@ func nflog(format string, args ...interface{}) { func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } @@ -84,7 +84,7 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } @@ -145,7 +145,7 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ + Entry: linux.IPTEntry{ IP: linux.IPTIP{ Protocol: uint16(rule.Filter.Protocol), }, @@ -153,20 +153,20 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin TargetOffset: linux.SizeOfIPTEntry, }, } - copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst) - copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src) - copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) if rule.Filter.DstInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP } if rule.Filter.SrcInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP } if rule.Filter.OutputInterfaceInvert { - entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT } for _, matcher := range rule.Matchers { @@ -178,8 +178,8 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) } // Serialize and append the target. @@ -188,11 +188,11 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) } entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + entry.Entry.NextOffset += uint16(len(serialized)) nflog("convert to binary: adding entry: %+v", entry) - entries.Size += uint32(entry.NextOffset) + entries.Size += uint32(entry.Entry.NextOffset) entries.Entrytable = append(entries.Entrytable, entry) info.NumEntries++ } @@ -342,10 +342,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.TablenameFilter: + case stack.FilterTable: table = stack.EmptyFilterTable() - case stack.TablenameNat: - table = stack.EmptyNatTable() + case stack.NATTable: + table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument @@ -431,6 +431,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { for hook, _ := range replace.HookEntry { if table.ValidHooks()&(1<<hook) != 0 { hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset for offset, ruleIdx := range offsets { if offset == replace.HookEntry[hook] { table.BuiltinChains[hk] = ruleIdx @@ -456,8 +458,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Add the user chains. for ruleIdx, rule := range table.Rules { - target, ok := rule.Target.(stack.UserChainTarget) - if !ok { + if _, ok := rule.Target.(stack.UserChainTarget); !ok { continue } @@ -473,7 +474,6 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { nflog("user chain's first node must have no matchers") return syserr.ErrInvalidArgument } - table.UserChains[target.Name] = ruleIdx + 1 } // Set each jump to point to the appropriate rule. Right now they hold byte @@ -499,7 +499,10 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, // make sure all other chains point to ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue + } if !isUnconditionalAccept(table.Rules[ruleIdx]) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument @@ -512,9 +515,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - stk.IPTables().ReplaceTable(replace.Name.String(), table) - - return nil + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 84abe8d29..b91ba3ab3 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -30,6 +30,6 @@ type JumpTarget struct { } // Action implements stack.Target.Action. -func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrackTable, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index d5ca3ac56..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -36,6 +36,8 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 81f34c5a2..98ca7add0 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -38,6 +38,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -330,7 +332,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -340,24 +342,26 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var passcred int32 + var passcred primitive.Int32 if s.Passcred() { passcred = 1 } - return passcred, nil + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index ea6ebd0e2..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -51,6 +51,8 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 738277391..f86e6cd7a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" @@ -61,6 +62,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -191,6 +194,8 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -295,8 +300,9 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -417,7 +423,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -445,8 +451,21 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) @@ -893,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -903,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -939,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -954,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -964,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -997,7 +1016,7 @@ func boolToInt32(v bool) int32 { } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -1008,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -1018,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -1033,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1049,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -1065,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { @@ -1076,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { @@ -1087,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1095,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } if v == 0 { - return []byte{}, nil + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument @@ -1110,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam // interface was removed. return nil, syserr.ErrUnknownDevice } - return append([]byte(nic.Name), 0), nil + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { @@ -1121,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { @@ -1132,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1146,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1154,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1166,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1175,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { @@ -1186,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(!v), nil + + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { @@ -1197,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { @@ -1208,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1219,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1231,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1243,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil case linux.TCP_KEEPCNT: if outLen < sizeOfInt32 { @@ -1255,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_USER_TIMEOUT: if outLen < sizeOfInt32 { @@ -1267,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Millisecond), nil + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1281,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1316,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + + bP := primitive.ByteSlice(b) + return &bP, nil case linux.TCP_LINGER2: if outLen < sizeOfInt32 { @@ -1328,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: if outLen < sizeOfInt32 { @@ -1340,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return nil, syserr.TranslateNetstackError(err) } - return int32(time.Duration(v) / time.Second), nil + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil case linux.TCP_SYNCNT: if outLen < sizeOfInt32 { @@ -1351,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_WINDOW_CLAMP: if outLen < sizeOfInt32 { @@ -1363,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1372,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1383,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1391,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil case linux.IPV6_RECVTCLASS: if outLen < sizeOfInt32 { @@ -1416,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIPv6(t, name) @@ -1425,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1438,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { @@ -1454,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1468,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { @@ -1479,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_RECVTOS: if outLen < sizeOfInt32 { @@ -1504,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { @@ -1515,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in if err != nil { return nil, syserr.TranslateNetstackError(err) } - return boolToInt32(v), nil + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: emitUnimplementedEventIP(t, name) @@ -1719,6 +1802,14 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument @@ -1733,6 +1824,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -2092,13 +2188,22 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -2313,7 +2418,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -2419,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. @@ -2474,6 +2596,11 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2708,11 +2835,16 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2720,9 +2852,7 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2741,9 +2871,8 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2752,52 +2881,49 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2810,9 +2936,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2826,9 +2951,8 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc } // Copy result to userspace. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2854,7 +2978,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2877,21 +3001,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2900,7 +3031,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2911,32 +3042,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2953,6 +3084,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2962,7 +3101,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -2999,9 +3138,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index d65a89316..a9025b0ec 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -31,6 +31,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketVFS2 rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { @@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. if err != nil { return nil, err } - return entries, nil + return &entries, nil } } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index ee11742a6..67737ae87 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,6 +15,8 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" @@ -40,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } @@ -143,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcpip.StackReceiveBufferSizeOption + var rs tcp.ReceiveBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -154,7 +166,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcpip.StackReceiveBufferSizeOption{ + rs := tcp.ReceiveBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -164,7 +176,7 @@ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcpip.StackSendBufferSizeOption + var ss tcp.SendBufferSizeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -175,7 +187,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcpip.StackSendBufferSizeOption{ + ss := tcp.SendBufferSizeOption{ Min: size.Min, Default: size.Default, Max: size.Max, @@ -185,14 +197,14 @@ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcpip.StackSACKEnabled + var sack tcp.SACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enabled))).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } // Statistics implements inet.Stack.Statistics. @@ -313,7 +325,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. 0, // Udp/SndbufErrors. - 0, // Udp/InCsumErrors. + udp.ChecksumErrors.Value(), // Udp/InCsumErrors. 0, // Udp/IgnoredMulti. } default: diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 6580bd6e9..d112757fb 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -86,7 +87,7 @@ type SocketOps interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -407,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { linux.SO_MARK, linux.SO_MAX_PACING_RATE, linux.SO_NOFCS, - linux.SO_NO_CHECK, linux.SO_OOBINLINE, linux.SO_PASSCRED, linux.SO_PASSSEC, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cca5e70f1..061a689a9 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -35,5 +35,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 09c6d3b27..a1e49cc57 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -476,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + if e.Connected() { return linux.SS_CONNECTED } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4bb2b6ff4..0482d33cf 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -40,6 +40,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -184,7 +185,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index ff2149250..05c16fcfe 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -89,7 +90,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index a6e48b836..5d51a7792 100644 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go @@ -50,10 +50,10 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui sb.WriteString("...") break } - if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok { - fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr) - continue - } + // Allowing addr to overflow is consistent with Linux, and harmless; if + // this isn't the last iteration of the loop, the next call to CopyIn + // will just fail with EFAULT. + addr, _ = addr.AddLength(uint64(linux.SizeOfEpollEvent)) } sb.WriteString("}") return sb.String() @@ -75,7 +75,7 @@ var epollEventEvents = abi.FlagSet{ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"}, {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"}, {Flag: linux.EPOLLERR, Name: "EPOLLERR"}, - {Flag: linux.EPOLLHUP, Name: "EPULLHUP"}, + {Flag: linux.EPOLLHUP, Name: "EPOLLHUP"}, {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"}, {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"}, {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"}, diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 217fcfef2..4a9b04fd0 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -99,5 +99,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index ea4f9b1a7..80c65164a 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -325,8 +325,8 @@ var AMD64 = &kernel.SyscallTable{ 270: syscalls.Supported("pselect", Pselect), 271: syscalls.Supported("ppoll", Ppoll), 272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil), - 273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil), - 274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil), + 273: syscalls.Supported("set_robust_list", SetRobustList), + 274: syscalls.Supported("get_robust_list", GetRobustList), 275: syscalls.Supported("splice", Splice), 276: syscalls.Supported("tee", Tee), 277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 696e1c8d3..8cf6401e7 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -900,14 +900,20 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) { +func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { a := file.Async(fasync.New).(*fasync.FileAsync) if who < 0 { + // Check for overflow before flipping the sign. + if who-1 > who { + return syserror.EINVAL + } pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(-who)) a.SetOwnerProcessGroup(t, pg) + } else { + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who)) + a.SetOwnerThreadGroup(t, tg) } - tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(who)) - a.SetOwnerThreadGroup(t, tg) + return nil } // Fcntl implements linux syscall fcntl(2). @@ -1042,8 +1048,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - fSetOwn(t, file, args[2].Int()) - return 0, nil, nil + return 0, nil, fSetOwn(t, file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1052,7 +1057,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_SETOWN_EX: addr := args[2].Pointer() var owner linux.FOwnerEx - n, err := t.CopyIn(addr, &owner) + _, err := t.CopyIn(addr, &owner) if err != nil { return 0, nil, err } @@ -1064,21 +1069,21 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ESRCH } a.SetOwnerTask(t, task) - return uintptr(n), nil, nil + return 0, nil, nil case linux.F_OWNER_PID: tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(owner.PID)) if tg == nil { return 0, nil, syserror.ESRCH } a.SetOwnerThreadGroup(t, tg) - return uintptr(n), nil, nil + return 0, nil, nil case linux.F_OWNER_PGRP: pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(owner.PID)) if pg == nil { return 0, nil, syserror.ESRCH } a.SetOwnerProcessGroup(t, pg) - return uintptr(n), nil, nil + return 0, nil, nil default: return 0, nil, syserror.EINVAL } @@ -1111,17 +1116,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } -// LINT.ThenChange(vfs2/fd.go) - -const ( - _FADV_NORMAL = 0 - _FADV_RANDOM = 1 - _FADV_SEQUENTIAL = 2 - _FADV_WILLNEED = 3 - _FADV_DONTNEED = 4 - _FADV_NOREUSE = 5 -) - // Fadvise64 implements linux syscall fadvise64(2). // This implementation currently ignores the provided advice. func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { @@ -1146,12 +1140,12 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } switch advice { - case _FADV_NORMAL: - case _FADV_RANDOM: - case _FADV_SEQUENTIAL: - case _FADV_WILLNEED: - case _FADV_DONTNEED: - case _FADV_NOREUSE: + case linux.POSIX_FADV_NORMAL: + case linux.POSIX_FADV_RANDOM: + case linux.POSIX_FADV_SEQUENTIAL: + case linux.POSIX_FADV_WILLNEED: + case linux.POSIX_FADV_DONTNEED: + case linux.POSIX_FADV_NOREUSE: default: return 0, nil, syserror.EINVAL } @@ -1160,8 +1154,6 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } -// LINT.IfChange - func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index b68261f72..f04d78856 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -198,7 +198,7 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall switch cmd { case linux.FUTEX_WAIT: // WAIT uses a relative timeout. - mask = ^uint32(0) + mask = linux.FUTEX_BITSET_MATCH_ANY var timeoutDur time.Duration if !forever { timeoutDur = time.Duration(timespec.ToNsecCapped()) * time.Nanosecond @@ -286,3 +286,49 @@ func Futex(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, syserror.ENOSYS } } + +// SetRobustList implements linux syscall set_robust_list(2). +func SetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + head := args[0].Pointer() + length := args[1].SizeT() + + if length != uint(linux.SizeOfRobustListHead) { + return 0, nil, syserror.EINVAL + } + t.SetRobustList(head) + return 0, nil, nil +} + +// GetRobustList implements linux syscall get_robust_list(2). +func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // Despite the syscall using the name 'pid' for this variable, it is + // very much a tid. + tid := args[0].Int() + head := args[1].Pointer() + size := args[2].Pointer() + + if tid < 0 { + return 0, nil, syserror.EINVAL + } + + ot := t + if tid != 0 { + if ot = t.PIDNamespace().TaskWithID(kernel.ThreadID(tid)); ot == nil { + return 0, nil, syserror.ESRCH + } + } + + // Copy out head pointer. + if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil { + return 0, nil, err + } + + // Copy out size, which is a constant. + if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil { + return 0, nil, err + } + + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0760af77b..414fce8e3 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -29,6 +29,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -474,7 +476,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -484,7 +486,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -496,13 +498,16 @@ func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr use switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -539,7 +544,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index c301a0991..64696b438 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -54,6 +54,7 @@ go_library( "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/fasync", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", @@ -71,5 +72,7 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index f5eaa076b..67f191551 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/fasync" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -154,6 +155,41 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } return uintptr(n), nil, nil + case linux.F_GETOWN: + owner, hasOwner := getAsyncOwner(t, file) + if !hasOwner { + return 0, nil, nil + } + if owner.Type == linux.F_OWNER_PGRP { + return uintptr(-owner.PID), nil, nil + } + return uintptr(owner.PID), nil, nil + case linux.F_SETOWN: + who := args[2].Int() + ownerType := int32(linux.F_OWNER_PID) + if who < 0 { + // Check for overflow before flipping the sign. + if who-1 > who { + return 0, nil, syserror.EINVAL + } + ownerType = linux.F_OWNER_PGRP + who = -who + } + return 0, nil, setAsyncOwner(t, file, ownerType, who) + case linux.F_GETOWN_EX: + owner, hasOwner := getAsyncOwner(t, file) + if !hasOwner { + return 0, nil, nil + } + _, err := t.CopyOut(args[2].Pointer(), &owner) + return 0, nil, err + case linux.F_SETOWN_EX: + var owner linux.FOwnerEx + _, err := t.CopyIn(args[2].Pointer(), &owner) + if err != nil { + return 0, nil, err + } + return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) case linux.F_GETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -177,6 +213,75 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } } +func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwnerEx, hasOwner bool) { + a := fd.AsyncHandler() + if a == nil { + return linux.FOwnerEx{}, false + } + + ot, otg, opg := a.(*fasync.FileAsync).Owner() + switch { + case ot != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_TID, + PID: int32(t.PIDNamespace().IDOfTask(ot)), + }, true + case otg != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_PID, + PID: int32(t.PIDNamespace().IDOfThreadGroup(otg)), + }, true + case opg != nil: + return linux.FOwnerEx{ + Type: linux.F_OWNER_PGRP, + PID: int32(t.PIDNamespace().IDOfProcessGroup(opg)), + }, true + default: + return linux.FOwnerEx{}, true + } +} + +func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { + switch ownerType { + case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: + // Acceptable type. + default: + return syserror.EINVAL + } + + a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + if pid == 0 { + a.ClearOwner() + return nil + } + + switch ownerType { + case linux.F_OWNER_TID: + task := t.PIDNamespace().TaskWithID(kernel.ThreadID(pid)) + if task == nil { + return syserror.ESRCH + } + a.SetOwnerTask(t, task) + return nil + case linux.F_OWNER_PID: + tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(pid)) + if tg == nil { + return syserror.ESRCH + } + a.SetOwnerThreadGroup(t, tg) + return nil + case linux.F_OWNER_PGRP: + pg := t.PIDNamespace().ProcessGroupWithID(kernel.ProcessGroupID(pid)) + if pg == nil { + return syserror.ESRCH + } + a.SetOwnerProcessGroup(t, pg) + return nil + default: + return syserror.EINVAL + } +} + func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescription, cmd int32) error { // Copy in the lock request. flockAddr := args[2].Pointer() @@ -210,3 +315,41 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip return syserror.EINVAL } } + +// Fadvise64 implements fadvise64(2). +// This implementation currently ignores the provided advice. +func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + length := args[2].Int64() + advice := args[3].Int() + + // Note: offset is allowed to be negative. + if length < 0 { + return 0, nil, syserror.EINVAL + } + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // If the FD refers to a pipe or FIFO, return error. + if _, isPipe := file.Impl().(*pipe.VFSPipeFD); isPipe { + return 0, nil, syserror.ESPIPE + } + + switch advice { + case linux.POSIX_FADV_NORMAL: + case linux.POSIX_FADV_RANDOM: + case linux.POSIX_FADV_SEQUENTIAL: + case linux.POSIX_FADV_WILLNEED: + case linux.POSIX_FADV_DONTNEED: + case linux.POSIX_FADV_NOREUSE: + default: + return 0, nil, syserror.EINVAL + } + + // Sure, whatever. + return 0, nil, nil +} diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go index 46d3e189c..6b14c2bef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/filesystem.go +++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -106,7 +107,7 @@ func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall addr := args[0].Pointer() mode := args[1].ModeT() dev := args[2].Uint() - return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev) + return 0, nil, mknodat(t, linux.AT_FDCWD, addr, linux.FileMode(mode), dev) } // Mknodat implements Linux syscall mknodat(2). @@ -115,10 +116,10 @@ func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca addr := args[1].Pointer() mode := args[2].ModeT() dev := args[3].Uint() - return 0, nil, mknodat(t, dirfd, addr, mode, dev) + return 0, nil, mknodat(t, dirfd, addr, linux.FileMode(mode), dev) } -func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error { +func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode linux.FileMode, dev uint32) error { path, err := copyInPath(t, addr) if err != nil { return err @@ -128,9 +129,14 @@ func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint return err } defer tpop.Release() + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + if mode.FileType() == 0 { + mode |= linux.ModeRegular + } major, minor := linux.DecodeDeviceID(dev) return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{ - Mode: linux.FileMode(mode &^ t.FSContext().Umask()), + Mode: mode &^ linux.FileMode(t.FSContext().Umask()), DevMajor: uint32(major), DevMinor: minor, }) @@ -239,6 +245,55 @@ func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd }) } +// Fallocate implements linux system call fallocate(2). +func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + mode := args[1].Uint64() + offset := args[2].Int64() + length := args[3].Int64() + + file := t.GetFileVFS2(fd) + + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + if !file.IsWritable() { + return 0, nil, syserror.EBADF + } + + if mode != 0 { + return 0, nil, syserror.ENOTSUP + } + + if offset < 0 || length <= 0 { + return 0, nil, syserror.EINVAL + } + + size := offset + length + + if size < 0 { + return 0, nil, syserror.EFBIG + } + + limit := limits.FromContext(t).Get(limits.FileSize).Cur + + if uint64(size) >= limit { + t.SendSignal(&arch.SignalInfo{ + Signo: int32(linux.SIGXFSZ), + Code: arch.SignalInfoUser, + }) + return 0, nil, syserror.EFBIG + } + + return 0, nil, file.Impl().Allocate(t, mode, uint64(offset), uint64(length)) + + // File length modified, generate notification. + // TODO(gvisor.dev/issue/1479): Reenable when Inotify is ported. + // file.Dirent.InotifyEvent(linux.IN_MODIFY, 0) +} + // Rmdir implements Linux syscall rmdir(2). func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { pathAddr := args[0].Pointer() @@ -313,6 +368,9 @@ func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpath if err != nil { return err } + if len(target) == 0 { + return syserror.ENOENT + } linkpath, err := copyInPath(t, linkpathAddr) if err != nil { return err diff --git a/pkg/sentry/syscalls/linux/vfs2/inotify.go b/pkg/sentry/syscalls/linux/vfs2/inotify.go index 7d50b6a16..5d98134a5 100644 --- a/pkg/sentry/syscalls/linux/vfs2/inotify.go +++ b/pkg/sentry/syscalls/linux/vfs2/inotify.go @@ -81,7 +81,7 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern // "EINVAL: The given event mask contains no valid events." // -- inotify_add_watch(2) - if validBits := mask & linux.ALL_INOTIFY_BITS; validBits == 0 { + if mask&linux.ALL_INOTIFY_BITS == 0 { return 0, nil, syserror.EINVAL } @@ -116,8 +116,11 @@ func InotifyAddWatch(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kern } defer d.DecRef() - fd = ino.AddWatch(d.Dentry(), mask) - return uintptr(fd), nil, err + fd, err = ino.AddWatch(d.Dentry(), mask) + if err != nil { + return 0, nil, err + } + return uintptr(fd), nil, nil } // InotifyRmWatch implements the inotify_rm_watch() syscall. diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 0399c0db4..fd6ab94b2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -57,6 +57,49 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall flags &^= linux.O_NONBLOCK } return 0, nil, file.SetStatusFlags(t, t.Credentials(), flags) + + case linux.FIOASYNC: + var set int32 + if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + return 0, nil, err + } + flags := file.StatusFlags() + if set != 0 { + flags |= linux.O_ASYNC + } else { + flags &^= linux.O_ASYNC + } + file.SetStatusFlags(t, t.Credentials(), flags) + return 0, nil, nil + + case linux.FIOGETOWN, linux.SIOCGPGRP: + var who int32 + owner, hasOwner := getAsyncOwner(t, file) + if hasOwner { + if owner.Type == linux.F_OWNER_PGRP { + who = -owner.PID + } else { + who = owner.PID + } + } + _, err := t.CopyOut(args[2].Pointer(), &who) + return 0, nil, err + + case linux.FIOSETOWN, linux.SIOCSPGRP: + var who int32 + if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil { + return 0, nil, err + } + ownerType := int32(linux.F_OWNER_PID) + if who < 0 { + // Check for overflow before flipping the sign. + if who-1 > who { + return 0, nil, syserror.EINVAL + } + ownerType = linux.F_OWNER_PGRP + who = -who + } + return 0, nil, setAsyncOwner(t, file, ownerType, who) } ret, err := file.Ioctl(t, t.MemoryManager(), args) diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index adeaa39cc..ea337de7c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -77,8 +77,7 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Silently allow MS_NOSUID, since we don't implement set-id bits // anyway. - const unsupportedFlags = linux.MS_NODEV | - linux.MS_NODIRATIME | linux.MS_STRICTATIME + const unsupportedFlags = linux.MS_NODIRATIME | linux.MS_STRICTATIME // Linux just allows passing any flags to mount(2) - it won't fail when // unknown or unsupported flags are passed. Since we don't implement @@ -94,6 +93,12 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if flags&linux.MS_NOEXEC == linux.MS_NOEXEC { opts.Flags.NoExec = true } + if flags&linux.MS_NODEV == linux.MS_NODEV { + opts.Flags.NoDev = true + } + if flags&linux.MS_NOSUID == linux.MS_NOSUID { + opts.Flags.NoSUID = true + } if flags&linux.MS_RDONLY == linux.MS_RDONLY { opts.ReadOnly = true } diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index 7f9debd4a..cd25597a7 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -606,3 +606,36 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall newoff, err := file.Seek(t, offset, whence) return uintptr(newoff), nil, err } + +// Readahead implements readahead(2). +func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + fd := args[0].Int() + offset := args[1].Int64() + size := args[2].SizeT() + + file := t.GetFileVFS2(fd) + if file == nil { + return 0, nil, syserror.EBADF + } + defer file.DecRef() + + // Check that the file is readable. + if !file.IsReadable() { + return 0, nil, syserror.EBADF + } + + // Check that the size is valid. + if int(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Check that the offset is legitimate and does not overflow. + if offset < 0 || offset+int64(size) < 0 { + return 0, nil, syserror.EINVAL + } + + // Return EINVAL; if the underlying file type does not support readahead, + // then Linux will return EINVAL to indicate as much. In the future, we + // may extend this function to actually support readahead hints. + return 0, nil, syserror.EINVAL +} diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 09ecfed26..6daedd173 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -178,6 +178,7 @@ func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc Mask: linux.STATX_SIZE, Size: uint64(length), }, + NeedWritePerm: true, }) return 0, nil, handleSetSizeError(t, err) } @@ -197,6 +198,10 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } defer file.DecRef() + if !file.IsWritable() { + return 0, nil, syserror.EINVAL + } + err := file.SetStat(t, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_SIZE, diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 10b668477..8096a8f9c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -30,6 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -477,7 +479,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } if v != nil { - if _, err := t.CopyOut(optValAddr, v); err != nil { + if _, err := v.CopyOut(t, optValAddr); err != nil { return 0, nil, err } } @@ -487,7 +489,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy // getSockOpt tries to handle common socket options, or dispatches to a specific // socket implementation. -func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (interface{}, *syserr.Error) { +func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr usermem.Addr, len int) (marshal.Marshallable, *syserr.Error) { if level == linux.SOL_SOCKET { switch name { case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL: @@ -499,13 +501,16 @@ func getSockOpt(t *kernel.Task, s socket.SocketVFS2, level, name int, optValAddr switch name { case linux.SO_TYPE: _, skType, _ := s.Type() - return int32(skType), nil + v := primitive.Int32(skType) + return &v, nil case linux.SO_DOMAIN: family, _, _ := s.Type() - return int32(family), nil + v := primitive.Int32(family) + return &v, nil case linux.SO_PROTOCOL: _, _, protocol := s.Type() - return int32(protocol), nil + v := primitive.Int32(protocol) + return &v, nil } } @@ -542,7 +547,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.EINVAL } buf := t.CopyScratchBuffer(int(optLen)) - if _, err := t.CopyIn(optValAddr, &buf); err != nil { + if _, err := t.CopyInBytes(optValAddr, buf); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 945a364a7..63ab11f8c 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -15,12 +15,15 @@ package vfs2 import ( + "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -110,16 +113,20 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Move data. var ( - n int64 - err error - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { // If both input and output are pipes, delegate to the pipe - // implementation. Otherwise, exactly one end is a pipe, which we - // ensure is consistently ordered after the non-pipe FD's locks by - // passing the pipe FD as usermem.IO to the non-pipe end. + // implementation. Otherwise, exactly one end is a pipe, which + // we ensure is consistently ordered after the non-pipe FD's + // locks by passing the pipe FD as usermem.IO to the non-pipe + // end. switch { case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) @@ -137,38 +144,15 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } else { n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } + default: + panic("not possible") } + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } - - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the splice operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. - } - if err = t.Block(inCh); err != nil { - break - } - } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + if err = dw.waitForBoth(t); err != nil { + break } } @@ -247,45 +231,256 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo // Copy data. var ( - inCh chan struct{} - outCh chan struct{} + n int64 + err error ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() for { - n, err := pipe.Tee(t, outPipeFD, inPipeFD, count) - if n != 0 { - return uintptr(n), nil, nil + n, err = pipe.Tee(t, outPipeFD, inPipeFD, count) + if n != 0 || err != syserror.ErrWouldBlock || nonBlock { + break + } + if err = dw.waitForBoth(t); err != nil { + break + } + } + if n == 0 { + return 0, nil, err + } + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// Sendfile implements linux system call sendfile(2). +func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + outFD := args[0].Int() + inFD := args[1].Int() + offsetAddr := args[2].Pointer() + count := int64(args[3].SizeT()) + + inFile := t.GetFileVFS2(inFD) + if inFile == nil { + return 0, nil, syserror.EBADF + } + defer inFile.DecRef() + if !inFile.IsReadable() { + return 0, nil, syserror.EBADF + } + + outFile := t.GetFileVFS2(outFD) + if outFile == nil { + return 0, nil, syserror.EBADF + } + defer outFile.DecRef() + if !outFile.IsWritable() { + return 0, nil, syserror.EBADF + } + + // Verify that the outFile Append flag is not set. + if outFile.StatusFlags()&linux.O_APPEND != 0 { + return 0, nil, syserror.EINVAL + } + + // Verify that inFile is a regular file or block device. This is a + // requirement; the same check appears in Linux + // (fs/splice.c:splice_direct_to_actor). + if stat, err := inFile.Stat(t, vfs.StatOptions{Mask: linux.STATX_TYPE}); err != nil { + return 0, nil, err + } else if stat.Mask&linux.STATX_TYPE == 0 || + (stat.Mode&linux.S_IFMT != linux.S_IFREG && stat.Mode&linux.S_IFMT != linux.S_IFBLK) { + return 0, nil, syserror.EINVAL + } + + // Copy offset if it exists. + offset := int64(-1) + if offsetAddr != 0 { + if inFile.Options().DenyPRead { + return 0, nil, syserror.ESPIPE } - if err != syserror.ErrWouldBlock || nonBlock { + if _, err := t.CopyIn(offsetAddr, &offset); err != nil { return 0, nil, err } + if offset < 0 { + return 0, nil, syserror.EINVAL + } + if offset+count < 0 { + return 0, nil, syserror.EINVAL + } + } + + // Validate count. This must come after offset checks. + if count < 0 { + return 0, nil, syserror.EINVAL + } + if count == 0 { + return 0, nil, nil + } + if count > int64(kernel.MAX_RW_COUNT) { + count = int64(kernel.MAX_RW_COUNT) + } - // Note that the blocking behavior here is a bit different than the - // normal pattern. Because we need to have both data to read and data - // to write simultaneously, we actually explicitly block on both of - // these cases in turn before returning to the tee operation. - if inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { - if inCh == nil { - inCh = make(chan struct{}, 1) - inW, _ := waiter.NewChannelEntry(inCh) - inFile.EventRegister(&inW, eventMaskRead) - defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Copy data. + var ( + n int64 + err error + ) + dw := dualWaiter{ + inFile: inFile, + outFile: outFile, + } + defer dw.destroy() + outPipeFD, outIsPipe := outFile.Impl().(*pipe.VFSPipeFD) + // Reading from input file should never block, since it is regular or + // block device. We only need to check if writing to the output file + // can block. + nonBlock := outFile.StatusFlags()&linux.O_NONBLOCK != 0 + if outIsPipe { + for n < count { + var spliceN int64 + if offset != -1 { + spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) + offset += spliceN + } else { + spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } - if err := t.Block(inCh); err != nil { - return 0, nil, err + n += spliceN + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) + } + if err != nil { + break } } - if outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, eventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. + } else { + // Read inFile to buffer, then write the contents to outFile. + buf := make([]byte, count) + for n < count { + var readN int64 + if offset != -1 { + readN, err = inFile.PRead(t, usermem.BytesIOSequence(buf), offset, vfs.ReadOptions{}) + offset += readN + } else { + readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + } + if readN == 0 && err == io.EOF { + // We reached the end of the file. Eat the + // error and exit the loop. + err = nil + break } - if err := t.Block(outCh); err != nil { - return 0, nil, err + n += readN + if err != nil { + break + } + + // Write all of the bytes that we read. This may need + // multiple write calls to complete. + wbuf := buf[:n] + for len(wbuf) > 0 { + var writeN int64 + writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) + wbuf = wbuf[writeN:] + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForOut(t) + } + if err != nil { + // We didn't complete the write. Only + // report the bytes that were actually + // written, and rewind the offset. + notWritten := int64(len(wbuf)) + n -= notWritten + if offset != -1 { + offset -= notWritten + } + break + } + } + if err == syserror.ErrWouldBlock && !nonBlock { + err = dw.waitForBoth(t) } + if err != nil { + break + } + } + } + + if offsetAddr != 0 { + // Copy out the new offset. + if _, err := t.CopyOut(offsetAddr, offset); err != nil { + return 0, nil, err + } + } + + if n == 0 { + return 0, nil, err + } + + inFile.Dentry().InotifyWithParent(linux.IN_ACCESS, 0, vfs.PathEvent) + outFile.Dentry().InotifyWithParent(linux.IN_MODIFY, 0, vfs.PathEvent) + return uintptr(n), nil, nil +} + +// dualWaiter is used to wait on one or both vfs.FileDescriptions. It is not +// thread-safe, and does not take a reference on the vfs.FileDescriptions. +// +// Users must call destroy() when finished. +type dualWaiter struct { + inFile *vfs.FileDescription + outFile *vfs.FileDescription + + inW waiter.Entry + inCh chan struct{} + outW waiter.Entry + outCh chan struct{} +} + +// waitForBoth waits for both dw.inFile and dw.outFile to be ready. +func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { + if dw.inFile.Readiness(eventMaskRead)&eventMaskRead == 0 { + if dw.inCh == nil { + dw.inW, dw.inCh = waiter.NewChannelEntry(nil) + dw.inFile.EventRegister(&dw.inW, eventMaskRead) + // We might be ready now. Try again before blocking. + return nil + } + if err := t.Block(dw.inCh); err != nil { + return err + } + } + return dw.waitForOut(t) +} + +// waitForOut waits for dw.outfile to be read. +func (dw *dualWaiter) waitForOut(t *kernel.Task) error { + if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready now. Try again before blocking. + return nil } + if err := t.Block(dw.outCh); err != nil { + return err + } + } + return nil +} + +// destroy cleans up resources help by dw. No more calls to wait* can occur +// after destroy is called. +func (dw *dualWaiter) destroy() { + if dw.inCh != nil { + dw.inFile.EventUnregister(&dw.inW) + dw.inCh = nil + } + if dw.outCh != nil { + dw.outFile.EventUnregister(&dw.outW) + dw.outCh = nil } + dw.inFile = nil + dw.outFile = nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index 365250b0b..0d0ebf46a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -65,10 +65,8 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel nbytes := args[2].Int64() flags := args[3].Uint() - if offset < 0 { - return 0, nil, syserror.EINVAL - } - if nbytes < 0 { + // Check for negative values and overflow. + if offset < 0 || offset+nbytes < 0 { return 0, nil, syserror.EINVAL } if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 { @@ -81,7 +79,37 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel } defer file.DecRef() - // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of - // [offset, offset+nbytes). - return 0, nil, file.Sync(t) + // TODO(gvisor.dev/issue/1897): Currently, the only file syncing we support + // is a full-file sync, i.e. fsync(2). As a result, there are severe + // limitations on how much we support sync_file_range: + // - In Linux, sync_file_range(2) doesn't write out the file's metadata, even + // if the file size is changed. We do. + // - We always sync the entire file instead of [offset, offset+nbytes). + // - We do not support the use of WAIT_BEFORE without WAIT_AFTER. For + // correctness, we would have to perform a write-out every time WAIT_BEFORE + // was used, but this would be much more expensive than expected if there + // were no write-out operations in progress. + // - Whenever WAIT_AFTER is used, we sync the file. + // - Ignore WRITE. If this flag is used with WAIT_AFTER, then the file will + // be synced anyway. If this flag is used without WAIT_AFTER, then it is + // safe (and less expensive) to do nothing, because the syscall will not + // wait for the write-out to complete--we only need to make sure that the + // next time WAIT_BEFORE or WAIT_AFTER are used, the write-out completes. + // - According to fs/sync.c, WAIT_BEFORE|WAIT_AFTER "will detect any I/O + // errors or ENOSPC conditions and will return those to the caller, after + // clearing the EIO and ENOSPC flags in the address_space." We don't do + // this. + + if flags&linux.SYNC_FILE_RANGE_WAIT_BEFORE != 0 && + flags&linux.SYNC_FILE_RANGE_WAIT_AFTER == 0 { + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, syserror.ENOSYS + } + + if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 { + if err := file.Sync(t); err != nil { + return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + } + } + return 0, nil, nil } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index caa6a98ff..c576d9475 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -44,7 +44,7 @@ func Override() { s.Table[23] = syscalls.Supported("select", Select) s.Table[32] = syscalls.Supported("dup", Dup) s.Table[33] = syscalls.Supported("dup2", Dup2) - delete(s.Table, 40) // sendfile + s.Table[40] = syscalls.Supported("sendfile", Sendfile) s.Table[41] = syscalls.Supported("socket", Socket) s.Table[42] = syscalls.Supported("connect", Connect) s.Table[43] = syscalls.Supported("accept", Accept) @@ -62,7 +62,7 @@ func Override() { s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt) s.Table[59] = syscalls.Supported("execve", Execve) s.Table[72] = syscalls.Supported("fcntl", Fcntl) - s.Table[73] = syscalls.Supported("fcntl", Flock) + s.Table[73] = syscalls.Supported("flock", Flock) s.Table[74] = syscalls.Supported("fsync", Fsync) s.Table[75] = syscalls.Supported("fdatasync", Fdatasync) s.Table[76] = syscalls.Supported("truncate", Truncate) @@ -92,7 +92,7 @@ func Override() { s.Table[162] = syscalls.Supported("sync", Sync) s.Table[165] = syscalls.Supported("mount", Mount) s.Table[166] = syscalls.Supported("umount2", Umount2) - delete(s.Table, 187) // readahead + s.Table[187] = syscalls.Supported("readahead", Readahead) s.Table[188] = syscalls.Supported("setxattr", Setxattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr) @@ -108,7 +108,7 @@ func Override() { s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) s.Table[213] = syscalls.Supported("epoll_create", EpollCreate) s.Table[217] = syscalls.Supported("getdents64", Getdents64) - delete(s.Table, 221) // fdavise64 + s.Table[221] = syscalls.PartiallySupported("fadvise64", Fadvise64, "The syscall is 'supported', but ignores all provided advice.", nil) s.Table[232] = syscalls.Supported("epoll_wait", EpollWait) s.Table[233] = syscalls.Supported("epoll_ctl", EpollCtl) s.Table[235] = syscalls.Supported("utimes", Utimes) @@ -138,7 +138,7 @@ func Override() { s.Table[282] = syscalls.Supported("signalfd", Signalfd) s.Table[283] = syscalls.Supported("timerfd_create", TimerfdCreate) s.Table[284] = syscalls.Supported("eventfd", Eventfd) - delete(s.Table, 285) // fallocate + s.Table[285] = syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil) s.Table[286] = syscalls.Supported("timerfd_settime", TimerfdSettime) s.Table[287] = syscalls.Supported("timerfd_gettime", TimerfdGettime) s.Table[288] = syscalls.Supported("accept4", Accept4) @@ -163,6 +163,106 @@ func Override() { // Override ARM64. s = linux.ARM64 + s.Table[5] = syscalls.Supported("setxattr", Setxattr) + s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) + s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) + s.Table[8] = syscalls.Supported("getxattr", Getxattr) + s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr) + s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr) + s.Table[11] = syscalls.Supported("listxattr", Listxattr) + s.Table[12] = syscalls.Supported("llistxattr", Llistxattr) + s.Table[13] = syscalls.Supported("flistxattr", Flistxattr) + s.Table[14] = syscalls.Supported("removexattr", Removexattr) + s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr) + s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr) + s.Table[17] = syscalls.Supported("getcwd", Getcwd) + s.Table[19] = syscalls.Supported("eventfd2", Eventfd2) + s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1) + s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl) + s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait) + s.Table[23] = syscalls.Supported("dup", Dup) + s.Table[24] = syscalls.Supported("dup3", Dup3) + s.Table[25] = syscalls.Supported("fcntl", Fcntl) + s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil) + s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil) + s.Table[29] = syscalls.Supported("ioctl", Ioctl) + s.Table[32] = syscalls.Supported("flock", Flock) + s.Table[33] = syscalls.Supported("mknodat", Mknodat) + s.Table[34] = syscalls.Supported("mkdirat", Mkdirat) + s.Table[35] = syscalls.Supported("unlinkat", Unlinkat) + s.Table[36] = syscalls.Supported("symlinkat", Symlinkat) + s.Table[37] = syscalls.Supported("linkat", Linkat) + s.Table[38] = syscalls.Supported("renameat", Renameat) + s.Table[39] = syscalls.Supported("umount2", Umount2) + s.Table[40] = syscalls.Supported("mount", Mount) + s.Table[43] = syscalls.Supported("statfs", Statfs) + s.Table[44] = syscalls.Supported("fstatfs", Fstatfs) + s.Table[45] = syscalls.Supported("truncate", Truncate) + s.Table[46] = syscalls.Supported("ftruncate", Ftruncate) + s.Table[48] = syscalls.Supported("faccessat", Faccessat) + s.Table[49] = syscalls.Supported("chdir", Chdir) + s.Table[50] = syscalls.Supported("fchdir", Fchdir) + s.Table[51] = syscalls.Supported("chroot", Chroot) + s.Table[52] = syscalls.Supported("fchmod", Fchmod) + s.Table[53] = syscalls.Supported("fchmodat", Fchmodat) + s.Table[54] = syscalls.Supported("fchownat", Fchownat) + s.Table[55] = syscalls.Supported("fchown", Fchown) + s.Table[56] = syscalls.Supported("openat", Openat) + s.Table[57] = syscalls.Supported("close", Close) + s.Table[59] = syscalls.Supported("pipe2", Pipe2) + s.Table[61] = syscalls.Supported("getdents64", Getdents64) + s.Table[62] = syscalls.Supported("lseek", Lseek) s.Table[63] = syscalls.Supported("read", Read) + s.Table[64] = syscalls.Supported("write", Write) + s.Table[65] = syscalls.Supported("readv", Readv) + s.Table[66] = syscalls.Supported("writev", Writev) + s.Table[67] = syscalls.Supported("pread64", Pread64) + s.Table[68] = syscalls.Supported("pwrite64", Pwrite64) + s.Table[69] = syscalls.Supported("preadv", Preadv) + s.Table[70] = syscalls.Supported("pwritev", Pwritev) + s.Table[72] = syscalls.Supported("pselect", Pselect) + s.Table[73] = syscalls.Supported("ppoll", Ppoll) + s.Table[74] = syscalls.Supported("signalfd4", Signalfd4) + s.Table[76] = syscalls.Supported("splice", Splice) + s.Table[77] = syscalls.Supported("tee", Tee) + s.Table[78] = syscalls.Supported("readlinkat", Readlinkat) + s.Table[80] = syscalls.Supported("fstat", Fstat) + s.Table[81] = syscalls.Supported("sync", Sync) + s.Table[82] = syscalls.Supported("fsync", Fsync) + s.Table[83] = syscalls.Supported("fdatasync", Fdatasync) + s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange) + s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate) + s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime) + s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime) + s.Table[88] = syscalls.Supported("utimensat", Utimensat) + s.Table[198] = syscalls.Supported("socket", Socket) + s.Table[199] = syscalls.Supported("socketpair", SocketPair) + s.Table[200] = syscalls.Supported("bind", Bind) + s.Table[201] = syscalls.Supported("listen", Listen) + s.Table[202] = syscalls.Supported("accept", Accept) + s.Table[203] = syscalls.Supported("connect", Connect) + s.Table[204] = syscalls.Supported("getsockname", GetSockName) + s.Table[205] = syscalls.Supported("getpeername", GetPeerName) + s.Table[206] = syscalls.Supported("sendto", SendTo) + s.Table[207] = syscalls.Supported("recvfrom", RecvFrom) + s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt) + s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt) + s.Table[210] = syscalls.Supported("shutdown", Shutdown) + s.Table[211] = syscalls.Supported("sendmsg", SendMsg) + s.Table[212] = syscalls.Supported("recvmsg", RecvMsg) + s.Table[221] = syscalls.Supported("execve", Execve) + s.Table[222] = syscalls.Supported("mmap", Mmap) + s.Table[242] = syscalls.Supported("accept4", Accept4) + s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg) + s.Table[267] = syscalls.Supported("syncfs", Syncfs) + s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg) + s.Table[276] = syscalls.Supported("renameat2", Renameat2) + s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate) + s.Table[281] = syscalls.Supported("execveat", Execveat) + s.Table[286] = syscalls.Supported("preadv2", Preadv2) + s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) + s.Table[291] = syscalls.Supported("statx", Statx) + s.Init() } diff --git a/pkg/sentry/time/parameters.go b/pkg/sentry/time/parameters.go index 65868cb26..cd1b95117 100644 --- a/pkg/sentry/time/parameters.go +++ b/pkg/sentry/time/parameters.go @@ -228,11 +228,15 @@ func errorAdjust(prevParams Parameters, newParams Parameters, now TSCValue) (Par // // The log level is determined by the error severity. func logErrorAdjustment(clock ClockID, errorNS ReferenceNS, orig, adjusted Parameters) { - fn := log.Debugf - if int64(errorNS.Magnitude()) > time.Millisecond.Nanoseconds() { + magNS := int64(errorNS.Magnitude()) + if magNS <= 10*time.Microsecond.Nanoseconds() { + // Don't log small errors. + return + } + fn := log.Infof + if magNS > time.Millisecond.Nanoseconds() { + // Upgrade large errors to warning. fn = log.Warningf - } else if int64(errorNS.Magnitude()) > 10*time.Microsecond.Nanoseconds() { - fn = log.Infof } fn("Clock(%v): error: %v ns, adjusted frequency from %v Hz to %v Hz", clock, errorNS, orig.Frequency, adjusted.Frequency) diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index b7c6b60b8..641e3e502 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -300,12 +300,15 @@ func (d *anonDentry) DecRef() { // InotifyWithParent implements DentryImpl.InotifyWithParent. // -// TODO(gvisor.dev/issue/1479): Implement inotify. -func (d *anonDentry) InotifyWithParent(events uint32, cookie uint32, et EventType) {} +// Although Linux technically supports inotify on pseudo filesystems (inotify +// is implemented at the vfs layer), it is not particularly useful. It is left +// unimplemented until someone actually needs it. +func (d *anonDentry) InotifyWithParent(events, cookie uint32, et EventType) {} // Watches implements DentryImpl.Watches. -// -// TODO(gvisor.dev/issue/1479): Implement inotify. func (d *anonDentry) Watches() *Watches { return nil } + +// OnZeroWatches implements Dentry.OnZeroWatches. +func (d *anonDentry) OnZeroWatches() {} diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index 24af13eb1..cea3e6955 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -113,12 +113,29 @@ type DentryImpl interface { // // Note that the events may not actually propagate up to the user, depending // on the event masks. - InotifyWithParent(events uint32, cookie uint32, et EventType) + InotifyWithParent(events, cookie uint32, et EventType) // Watches returns the set of inotify watches for the file corresponding to // the Dentry. Dentries that are hard links to the same underlying file // share the same watches. + // + // Watches may return nil if the dentry belongs to a FilesystemImpl that + // does not support inotify. If an implementation returns a non-nil watch + // set, it must always return a non-nil watch set. Likewise, if an + // implementation returns a nil watch set, it must always return a nil watch + // set. + // + // The caller does not need to hold a reference on the dentry. Watches() *Watches + + // OnZeroWatches is called whenever the number of watches on a dentry drops + // to zero. This is needed by some FilesystemImpls (e.g. gofer) to manage + // dentry lifetime. + // + // The caller does not need to hold a reference on the dentry. OnZeroWatches + // may acquire inotify locks, so to prevent deadlock, no inotify locks should + // be held by the caller. + OnZeroWatches() } // IncRef increments d's reference count. @@ -149,17 +166,26 @@ func (d *Dentry) isMounted() bool { return atomic.LoadUint32(&d.mounts) != 0 } -// InotifyWithParent notifies all watches on the inodes for this dentry and +// InotifyWithParent notifies all watches on the targets represented by d and // its parent of events. -func (d *Dentry) InotifyWithParent(events uint32, cookie uint32, et EventType) { +func (d *Dentry) InotifyWithParent(events, cookie uint32, et EventType) { d.impl.InotifyWithParent(events, cookie, et) } // Watches returns the set of inotify watches associated with d. +// +// Watches will return nil if d belongs to a FilesystemImpl that does not +// support inotify. func (d *Dentry) Watches() *Watches { return d.impl.Watches() } +// OnZeroWatches performs cleanup tasks whenever the number of watches on a +// dentry drops to zero. +func (d *Dentry) OnZeroWatches() { + d.impl.OnZeroWatches() +} + // The following functions are exported so that filesystem implementations can // use them. The vfs package, and users of VFS, should not call these // functions. diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 599c3131c..5b009b928 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -186,7 +186,7 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin } // Register interest in file. - mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP + mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP epi := &epollInterest{ epoll: ep, key: key, @@ -257,7 +257,7 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event } // Update epi for the next call to ep.ReadEvents(). - mask := event.Events | linux.EPOLLERR | linux.EPOLLRDHUP + mask := event.Events | linux.EPOLLERR | linux.EPOLLHUP ep.mu.Lock() epi.mask = mask epi.userData = event.Data diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index e0538ea53..0c42574db 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -42,11 +42,20 @@ type FileDescription struct { // operations. refs int64 + // flagsMu protects statusFlags and asyncHandler below. + flagsMu sync.Mutex + // statusFlags contains status flags, "initialized by open(2) and possibly - // modified by fcntl()" - fcntl(2). statusFlags is accessed using atomic - // memory operations. + // modified by fcntl()" - fcntl(2). statusFlags can be read using atomic + // memory operations when it does not need to be synchronized with an + // access to asyncHandler. statusFlags uint32 + // asyncHandler handles O_ASYNC signal generation. It is set with the + // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must + // also be set by fcntl(2). + asyncHandler FileAsync + // epolls is the set of epollInterests registered for this FileDescription. // epolls is protected by epollMu. epollMu sync.Mutex @@ -82,8 +91,7 @@ type FileDescription struct { // FileDescriptionOptions contains options to FileDescription.Init(). type FileDescriptionOptions struct { - // If AllowDirectIO is true, allow O_DIRECT to be set on the file. This is - // usually only the case if O_DIRECT would actually have an effect. + // If AllowDirectIO is true, allow O_DIRECT to be set on the file. AllowDirectIO bool // If DenyPRead is true, calls to FileDescription.PRead() return ESPIPE. @@ -193,6 +201,13 @@ func (fd *FileDescription) DecRef() { fd.vd.mount.EndWrite() } fd.vd.DecRef() + fd.flagsMu.Lock() + // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Unregister(fd) + } + fd.asyncHandler = nil + fd.flagsMu.Unlock() } else if refs < 0 { panic("FileDescription.DecRef() called without holding a reference") } @@ -276,7 +291,18 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede } // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK - atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) + fd.flagsMu.Lock() + if fd.asyncHandler != nil { + // Use fd.statusFlags instead of oldFlags, which may have become outdated, + // to avoid double registering/unregistering. + if fd.statusFlags&linux.O_ASYNC == 0 && flags&linux.O_ASYNC != 0 { + fd.asyncHandler.Register(fd) + } else if fd.statusFlags&linux.O_ASYNC != 0 && flags&linux.O_ASYNC == 0 { + fd.asyncHandler.Unregister(fd) + } + } + fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags) + fd.flagsMu.Unlock() return nil } @@ -328,6 +354,10 @@ type FileDescriptionImpl interface { // represented by the FileDescription. StatFS(ctx context.Context) (linux.Statfs, error) + // Allocate grows file represented by FileDescription to offset + length bytes. + // Only mode == 0 is supported currently. + Allocate(ctx context.Context, mode, offset, length uint64) error + // waiter.Waitable methods may be used to poll for I/O events. waiter.Waitable @@ -533,17 +563,23 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.impl.StatFS(ctx) } -// Readiness returns fd's I/O readiness. +// Readiness implements waiter.Waitable.Readiness. +// +// It returns fd's I/O readiness. func (fd *FileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { return fd.impl.Readiness(mask) } -// EventRegister registers e for I/O readiness events in mask. +// EventRegister implements waiter.Waitable.EventRegister. +// +// It registers e for I/O readiness events in mask. func (fd *FileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { fd.impl.EventRegister(e, mask) } -// EventUnregister unregisters e for I/O readiness events. +// EventUnregister implements waiter.Waitable.EventUnregister. +// +// It unregisters e for I/O readiness events. func (fd *FileDescription) EventUnregister(e *waiter.Entry) { fd.impl.EventUnregister(e) } @@ -770,3 +806,32 @@ func (fd *FileDescription) LockPOSIX(ctx context.Context, uid lock.UniqueID, t l func (fd *FileDescription) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, start, end uint64, whence int16) error { return fd.impl.UnlockPOSIX(ctx, uid, start, end, whence) } + +// A FileAsync sends signals to its owner when w is ready for IO. This is only +// implemented by pkg/sentry/fasync:FileAsync, but we unfortunately need this +// interface to avoid circular dependencies. +type FileAsync interface { + Register(w waiter.Waitable) + Unregister(w waiter.Waitable) +} + +// AsyncHandler returns the FileAsync for fd. +func (fd *FileDescription) AsyncHandler() FileAsync { + fd.flagsMu.Lock() + defer fd.flagsMu.Unlock() + return fd.asyncHandler +} + +// SetAsyncHandler sets fd.asyncHandler if it has not been set before and +// returns it. +func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsync { + fd.flagsMu.Lock() + defer fd.flagsMu.Unlock() + if fd.asyncHandler == nil { + fd.asyncHandler = newHandler() + if fd.statusFlags&linux.O_ASYNC != 0 { + fd.asyncHandler.Register(fd) + } + } + return fd.asyncHandler +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 1e66997ce..6b8b4ad49 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -56,6 +56,12 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err return linux.Statfs{}, syserror.ENOSYS } +// Allocate implements FileDescriptionImpl.Allocate analogously to +// fallocate called on regular file, directory or FIFO in Linux. +func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + // Readiness implements waiter.Waitable.Readiness analogously to // file_operations::poll == NULL in Linux. func (FileDescriptionDefaultImpl) Readiness(mask waiter.EventMask) waiter.EventMask { @@ -158,6 +164,11 @@ func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) // implementations of non-directory I/O methods that return EISDIR. type DirectoryFileDescriptionDefaultImpl struct{} +// Allocate implements DirectoryFileDescriptionDefaultImpl.Allocate. +func (DirectoryFileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EISDIR +} + // PRead implements FileDescriptionImpl.PRead. func (DirectoryFileDescriptionDefaultImpl) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.EISDIR @@ -327,7 +338,7 @@ func (fd *DynamicBytesFileDescriptionImpl) pwriteLocked(ctx context.Context, src writable, ok := fd.data.(WritableDynamicBytesSource) if !ok { - return 0, syserror.EINVAL + return 0, syserror.EIO } n, err := writable.Write(ctx, src, offset) if err != nil { diff --git a/pkg/sentry/vfs/file_description_impl_util_test.go b/pkg/sentry/vfs/file_description_impl_util_test.go index 5061f6ac9..3b7e1c273 100644 --- a/pkg/sentry/vfs/file_description_impl_util_test.go +++ b/pkg/sentry/vfs/file_description_impl_util_test.go @@ -155,11 +155,11 @@ func TestGenCountFD(t *testing.T) { } // Write and PWrite fails. - if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) + if _, err := fd.Write(ctx, ioseq, WriteOptions{}); err != syserror.EIO { + t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } - if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EINVAL { - t.Errorf("Write: got err %v, wanted %v", err, syserror.EINVAL) + if _, err := fd.PWrite(ctx, ioseq, 0, WriteOptions{}); err != syserror.EIO { + t.Errorf("Write: got err %v, wanted %v", err, syserror.EIO) } } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 1edd584c9..6bb9ca180 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -524,8 +524,6 @@ type FilesystemImpl interface { // // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error - - // TODO(gvisor.dev/issue/1479): inotify_add_watch() } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md new file mode 100644 index 000000000..e7da49faa --- /dev/null +++ b/pkg/sentry/vfs/g3doc/inotify.md @@ -0,0 +1,210 @@ +# Inotify + +Inotify is a mechanism for monitoring filesystem events in Linux--see +inotify(7). An inotify instance can be used to monitor files and directories for +modifications, creation/deletion, etc. The inotify API consists of system calls +that create inotify instances (inotify_init/inotify_init1) and add/remove +watches on files to an instance (inotify_add_watch/inotify_rm_watch). Events are +generated from various places in the sentry, including the syscall layer, the +vfs layer, the process fd table, and within each filesystem implementation. This +document outlines the implementation details of inotify in VFS2. + +## Inotify Objects + +Inotify data structures are implemented in the vfs package. + +### vfs.Inotify + +Inotify instances are represented by vfs.Inotify objects, which implement +vfs.FileDescriptionImpl. As in Linux, inotify fds are backed by a +pseudo-filesystem (anonfs). Each inotify instance receives events from a set of +vfs.Watch objects, which can be modified with inotify_add_watch(2) and +inotify_rm_watch(2). An application can retrieve events by reading the inotify +fd. + +### vfs.Watches + +The set of all watches held on a single file (i.e., the watch target) is stored +in vfs.Watches. Each watch will belong to a different inotify instance (an +instance can only have one watch on any watch target). The watches are stored in +a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions +to a single file will all share the same vfs.Watches. Activity on the target +causes its vfs.Watches to generate notifications on its watches’ inotify +instances. + +### vfs.Watch + +A single watch, owned by one inotify instance and applied to one watch target. +Both the vfs.Inotify owner and vfs.Watches on the target will hold a vfs.Watch, +which leads to some complicated locking behavior (see Lock Ordering). Whenever a +watch is notified of an event on its target, it will queue events to its inotify +instance for delivery to the user. + +### vfs.Event + +vfs.Event is a simple struct encapsulating all the fields for an inotify event. +It is generated by vfs.Watches and forwarded to the watches' owners. It is +serialized to the user during read(2) syscalls on the associated fs.Inotify's +fd. + +## Lock Ordering + +There are three locks related to the inotify implementation: + +Inotify.mu: the inotify instance lock. Inotify.evMu: the inotify event queue +lock. Watches.mu: the watch set lock, used to protect the collection of watches +on a target. + +The correct lock ordering for inotify code is: + +Inotify.mu -> Watches.mu -> Inotify.evMu. + +Note that we use a distinct lock to protect the inotify event queue. If we +simply used Inotify.mu, we could simultaneously have locks being acquired in the +order of Inotify.mu -> Watches.mu and Watches.mu -> Inotify.mu, which would +cause deadlocks. For instance, adding a watch to an inotify instance would +require locking Inotify.mu, and then adding the same watch to the target would +cause Watches.mu to be held. At the same time, generating an event on the target +would require Watches.mu to be held before iterating through each watch, and +then notifying the owner of each watch would cause Inotify.mu to be held. + +See the vfs package comment to understand how inotify locks fit into the overall +ordering of filesystem locks. + +## Watch Targets in Different Filesystem Implementations + +In Linux, watches reside on inodes at the virtual filesystem layer. As a result, +all hard links and file descriptions on a single file will all share the same +watch set. In VFS2, there is no common inode structure across filesystem types +(some may not even have inodes), so we have to plumb inotify support through +each specific filesystem implementation. Some of the technical considerations +are outlined below. + +### Tmpfs + +For filesystems with inodes, like tmpfs, the design is quite similar to that of +Linux, where watches reside on the inode. + +### Pseudo-filesystems + +Technically, because inotify is implemented at the vfs layer in Linux, +pseudo-filesystems on top of kernfs support inotify passively. However, watches +can only track explicit filesystem operations like read/write, open/close, +mknod, etc., so watches on a target like /proc/self/fd will not generate events +every time a new fd is added or removed. As of this writing, we leave inotify +unimplemented in kernfs and anonfs; it does not seem particularly useful. + +### Gofer Filesystem (fsimpl/gofer) + +The gofer filesystem has several traits that make it difficult to support +inotify: + +* **There are no inodes.** A file is represented as a dentry that holds an + unopened p9 file (and possibly an open FID), through which the Sentry + interacts with the gofer. + * *Solution:* Because there is no inode structure stored in the sandbox, + inotify watches must be held on the dentry. This would be an issue in + the presence of hard links, where multiple dentries would need to share + the same set of watches, but in VFS2, we do not support the internal + creation of hard links on gofer fs. As a result, we make the assumption + that every dentry corresponds to a unique inode. However, the next point + raises an issue with this assumption: +* **The Sentry cannot always be aware of hard links on the remote + filesystem.** There is no way for us to confirm whether two files on the + remote filesystem are actually links to the same inode. QIDs and inodes are + not always 1:1. The assumption that dentries and inodes are 1:1 is + inevitably broken if there are remote hard links that we cannot detect. + * *Solution:* this is an issue with gofer fs in general, not only inotify, + and we will have to live with it. +* **Dentries can be cached, and then evicted.** Dentry lifetime does not + correspond to file lifetime. Because gofer fs is not entirely in-memory, the + absence of a dentry does not mean that the corresponding file does not + exist, nor does a dentry reaching zero references mean that the + corresponding file no longer exists. When a dentry reaches zero references, + it will be cached, in case the file at that path is needed again in the + future. However, the dentry may be evicted from the cache, which will cause + a new dentry to be created next time the same file path is used. The + existing watches will be lost. + * *Solution:* When a dentry reaches zero references, do not cache it if it + has any watches, so we can avoid eviction/destruction. Note that if the + dentry was deleted or invalidated (d.vfsd.IsDead()), we should still + destroy it along with its watches. Additionally, when a dentry’s last + watch is removed, we cache it if it also has zero references. This way, + the dentry can eventually be evicted from memory if it is no longer + needed. +* **Dentries can be invalidated.** Another issue with dentry lifetime is that + the remote file at the file path represented may change from underneath the + dentry. In this case, the next time that the dentry is used, it will be + invalidated and a new dentry will replace it. In this case, it is not clear + what should be done with the watches on the old dentry. + * *Solution:* Silently destroy the watches when invalidation occurs. We + have no way of knowing exactly what happened, when it happens. Inotify + instances on NFS files in Linux probably behave in a similar fashion, + since inotify is implemented at the vfs layer and is not aware of the + complexities of remote file systems. + * An alternative would be to issue some kind of event upon invalidation, + e.g. a delete event, but this has several issues: + * We cannot discern whether the remote file was invalidated because it was + moved, deleted, etc. This information is crucial, because these cases + should result in different events. Furthermore, the watches should only + be destroyed if the file has been deleted. + * Moreover, the mechanism for detecting whether the underlying file has + changed is to check whether a new QID is given by the gofer. This may + result in false positives, e.g. suppose that the server closed and + re-opened the same file, which may result in a new QID. + * Finally, the time of the event may be completely different from the time + of the file modification, since a dentry is not immediately notified + when the underlying file has changed. It would be quite unexpected to + receive the notification when invalidation was triggered, i.e. the next + time the file was accessed within the sandbox, because then the + read/write/etc. operation on the file would not result in the expected + event. + * Another point in favor of the first solution: inotify in Linux can + already be lossy on local filesystems (one of the sacrifices made so + that filesystem performance isn’t killed), and it is lossy on NFS for + similar reasons to gofer fs. Therefore, it is better for inotify to be + silent than to emit incorrect notifications. +* **There may be external users of the remote filesystem.** We can only track + operations performed on the file within the sandbox. This is sufficient + under InteropModeExclusive, but whenever there are external users, the set + of actions we are aware of is incomplete. + * *Solution:* We could either return an error or just issue a warning when + inotify is used without InteropModeExclusive. Although faulty, VFS1 + allows it when the filesystem is shared, and Linux does the same for + remote filesystems (as mentioned above, inotify sits at the vfs level). + +## Dentry Interface + +For events that must be generated above the vfs layer, we provide the following +DentryImpl methods to allow interactions with targets on any FilesystemImpl: + +* **InotifyWithParent()** generates events on the dentry’s watches as well as + its parent’s. +* **Watches()** retrieves the watch set of the target represented by the + dentry. This is used to access and modify watches on a target. +* **OnZeroWatches()** performs cleanup tasks after the last watch is removed + from a dentry. This is needed by gofer fs, which must allow a watched dentry + to be cached once it has no more watches. Most implementations can just do + nothing. Note that OnZeroWatches() must be called after all inotify locks + are released to preserve lock ordering, since it may acquire + FilesystemImpl-specific locks. + +## IN_EXCL_UNLINK + +There are several options that can be set for a watch, specified as part of the +mask in inotify_add_watch(2). In particular, IN_EXCL_UNLINK requires some +additional support in each filesystem. + +A watch with IN_EXCL_UNLINK will not generate events for its target if it +corresponds to a path that was unlinked. For instance, if an fd is opened on +“foo/bar” and “foo/bar” is subsequently unlinked, any reads/writes/etc. on the +fd will be ignored by watches on “foo” or “foo/bar” with IN_EXCL_UNLINK. This +requires each DentryImpl to keep track of whether it has been unlinked, in order +to determine whether events should be sent to watches with IN_EXCL_UNLINK. + +## IN_ONESHOT + +One-shot watches expire after generating a single event. When an event occurs, +all one-shot watches on the target that successfully generated an event are +removed. Lock ordering can cause the management of one-shot watches to be quite +expensive; see Watches.Notify() for more information. diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 7fa7d2d0c..167b731ac 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -49,9 +49,6 @@ const ( // Inotify represents an inotify instance created by inotify_init(2) or // inotify_init1(2). Inotify implements FileDescriptionImpl. // -// Lock ordering: -// Inotify.mu -> Watches.mu -> Inotify.evMu -// // +stateify savable type Inotify struct { vfsfd FileDescription @@ -122,20 +119,40 @@ func NewInotifyFD(ctx context.Context, vfsObj *VirtualFilesystem, flags uint32) // Release implements FileDescriptionImpl.Release. Release removes all // watches and frees all resources for an inotify instance. func (i *Inotify) Release() { + var ds []*Dentry + // We need to hold i.mu to avoid a race with concurrent calls to // Inotify.handleDeletion from Watches. There's no risk of Watches // accessing this Inotify after the destructor ends, because we remove all // references to it below. i.mu.Lock() - defer i.mu.Unlock() for _, w := range i.watches { // Remove references to the watch from the watches set on the target. We // don't need to worry about the references from i.watches, since this // file description is about to be destroyed. - w.set.Remove(i.id) + d := w.target + ws := d.Watches() + // Watchable dentries should never return a nil watch set. + if ws == nil { + panic("Cannot remove watch from an unwatchable dentry") + } + ws.Remove(i.id) + if ws.Size() == 0 { + ds = append(ds, d) + } + } + i.mu.Unlock() + + for _, d := range ds { + d.OnZeroWatches() } } +// Allocate implements FileDescription.Allocate. +func (i *Inotify) Allocate(ctx context.Context, mode, offset, length uint64) error { + panic("Allocate should not be called on read-only inotify fds") +} + // EventRegister implements waiter.Waitable. func (i *Inotify) EventRegister(e *waiter.Entry, mask waiter.EventMask) { i.queue.EventRegister(e, mask) @@ -162,12 +179,12 @@ func (i *Inotify) Readiness(mask waiter.EventMask) waiter.EventMask { return mask & ready } -// PRead implements FileDescriptionImpl. +// PRead implements FileDescriptionImpl.PRead. func (*Inotify) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) { return 0, syserror.ESPIPE } -// PWrite implements FileDescriptionImpl. +// PWrite implements FileDescriptionImpl.PWrite. func (*Inotify) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) { return 0, syserror.ESPIPE } @@ -226,7 +243,7 @@ func (i *Inotify) Read(ctx context.Context, dst usermem.IOSequence, opts ReadOpt return writeLen, nil } -// Ioctl implements fs.FileOperations.Ioctl. +// Ioctl implements FileDescriptionImpl.Ioctl. func (i *Inotify) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch args[1].Int() { case linux.FIONREAD: @@ -272,20 +289,19 @@ func (i *Inotify) queueEvent(ev *Event) { // newWatchLocked creates and adds a new watch to target. // -// Precondition: i.mu must be locked. -func (i *Inotify) newWatchLocked(target *Dentry, mask uint32) *Watch { - targetWatches := target.Watches() +// Precondition: i.mu must be locked. ws must be the watch set for target d. +func (i *Inotify) newWatchLocked(d *Dentry, ws *Watches, mask uint32) *Watch { w := &Watch{ - owner: i, - wd: i.nextWatchIDLocked(), - set: targetWatches, - mask: mask, + owner: i, + wd: i.nextWatchIDLocked(), + target: d, + mask: mask, } // Hold the watch in this inotify instance as well as the watch set on the // target. i.watches[w.wd] = w - targetWatches.Add(w) + ws.Add(w) return w } @@ -297,22 +313,11 @@ func (i *Inotify) nextWatchIDLocked() int32 { return i.nextWatchMinusOne } -// handleDeletion handles the deletion of the target of watch w. It removes w -// from i.watches and a watch removal event is generated. -func (i *Inotify) handleDeletion(w *Watch) { - i.mu.Lock() - _, found := i.watches[w.wd] - delete(i.watches, w.wd) - i.mu.Unlock() - - if found { - i.queueEvent(newEvent(w.wd, "", linux.IN_IGNORED, 0)) - } -} - // AddWatch constructs a new inotify watch and adds it to the target. It // returns the watch descriptor returned by inotify_add_watch(2). -func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { +// +// The caller must hold a reference on target. +func (i *Inotify) AddWatch(target *Dentry, mask uint32) (int32, error) { // Note: Locking this inotify instance protects the result returned by // Lookup() below. With the lock held, we know for sure the lookup result // won't become stale because it's impossible for *this* instance to @@ -320,8 +325,14 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { i.mu.Lock() defer i.mu.Unlock() + ws := target.Watches() + if ws == nil { + // While Linux supports inotify watches on all filesystem types, watches on + // filesystems like kernfs are not generally useful, so we do not. + return 0, syserror.EPERM + } // Does the target already have a watch from this inotify instance? - if existing := target.Watches().Lookup(i.id); existing != nil { + if existing := ws.Lookup(i.id); existing != nil { newmask := mask if mask&linux.IN_MASK_ADD != 0 { // "Add (OR) events to watch mask for this pathname if it already @@ -329,12 +340,12 @@ func (i *Inotify) AddWatch(target *Dentry, mask uint32) int32 { newmask |= atomic.LoadUint32(&existing.mask) } atomic.StoreUint32(&existing.mask, newmask) - return existing.wd + return existing.wd, nil } // No existing watch, create a new watch. - w := i.newWatchLocked(target, mask) - return w.wd + w := i.newWatchLocked(target, ws, mask) + return w.wd, nil } // RmWatch looks up an inotify watch for the given 'wd' and configures the @@ -353,9 +364,19 @@ func (i *Inotify) RmWatch(wd int32) error { delete(i.watches, wd) // Remove the watch from the watch target. - w.set.Remove(w.OwnerID()) + ws := w.target.Watches() + // AddWatch ensures that w.target has a non-nil watch set. + if ws == nil { + panic("Watched dentry cannot have nil watch set") + } + ws.Remove(w.OwnerID()) + remaining := ws.Size() i.mu.Unlock() + if remaining == 0 { + w.target.OnZeroWatches() + } + // Generate the event for the removal. i.queueEvent(newEvent(wd, "", linux.IN_IGNORED, 0)) @@ -374,6 +395,13 @@ type Watches struct { ws map[uint64]*Watch } +// Size returns the number of watches held by w. +func (w *Watches) Size() int { + w.mu.Lock() + defer w.mu.Unlock() + return len(w.ws) +} + // Lookup returns the watch owned by an inotify instance with the given id. // Returns nil if no such watch exists. // @@ -424,64 +452,86 @@ func (w *Watches) Remove(id uint64) { return } - if _, ok := w.ws[id]; !ok { - // While there's technically no problem with silently ignoring a missing - // watch, this is almost certainly a bug. - panic(fmt.Sprintf("Attempt to remove a watch, but no watch found with provided id %+v.", id)) + // It is possible for w.Remove() to be called for the same watch multiple + // times. See the treatment of one-shot watches in Watches.Notify(). + if _, ok := w.ws[id]; ok { + delete(w.ws, id) } - delete(w.ws, id) } -// Notify queues a new event with all watches in this set. -func (w *Watches) Notify(name string, events, cookie uint32, et EventType) { - w.NotifyWithExclusions(name, events, cookie, et, false) +// Notify queues a new event with watches in this set. Watches with +// IN_EXCL_UNLINK are skipped if the event is coming from a child that has been +// unlinked. +func (w *Watches) Notify(name string, events, cookie uint32, et EventType, unlinked bool) { + var hasExpired bool + w.mu.RLock() + for _, watch := range w.ws { + if unlinked && watch.ExcludeUnlinked() && et == PathEvent { + continue + } + if watch.Notify(name, events, cookie) { + hasExpired = true + } + } + w.mu.RUnlock() + + if hasExpired { + w.cleanupExpiredWatches() + } } -// NotifyWithExclusions queues a new event with watches in this set. Watches -// with IN_EXCL_UNLINK are skipped if the event is coming from a child that -// has been unlinked. -func (w *Watches) NotifyWithExclusions(name string, events, cookie uint32, et EventType, unlinked bool) { - // N.B. We don't defer the unlocks because Notify is in the hot path of - // all IO operations, and the defer costs too much for small IO - // operations. +// This function is relatively expensive and should only be called where there +// are expired watches. +func (w *Watches) cleanupExpiredWatches() { + // Because of lock ordering, we cannot acquire Inotify.mu for each watch + // owner while holding w.mu. As a result, store expired watches locally + // before removing. + var toRemove []*Watch w.mu.RLock() for _, watch := range w.ws { - if unlinked && watch.ExcludeUnlinkedChildren() && et == PathEvent { - continue + if atomic.LoadInt32(&watch.expired) == 1 { + toRemove = append(toRemove, watch) } - watch.Notify(name, events, cookie) } w.mu.RUnlock() + for _, watch := range toRemove { + watch.owner.RmWatch(watch.wd) + } } -// HandleDeletion is called when the watch target is destroyed to emit -// the appropriate events. +// HandleDeletion is called when the watch target is destroyed. Clear the +// watch set, detach watches from the inotify instances they belong to, and +// generate the appropriate events. func (w *Watches) HandleDeletion() { - w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent) + w.Notify("", linux.IN_DELETE_SELF, 0, InodeEvent, true /* unlinked */) - // TODO(gvisor.dev/issue/1479): This doesn't work because maps are not copied - // by value. Ideally, we wouldn't have this circular locking so we can just - // notify of IN_DELETE_SELF in the same loop below. - // - // We can't hold w.mu while calling watch.handleDeletion to preserve lock - // ordering w.r.t to the owner inotify instances. Instead, atomically move - // the watches map into a local variable so we can iterate over it safely. - // - // Because of this however, it is possible for the watches' owners to reach - // this inode while the inode has no refs. This is still safe because the - // owners can only reach the inode until this function finishes calling - // watch.handleDeletion below and the inode is guaranteed to exist in the - // meantime. But we still have to be very careful not to rely on inode state - // that may have been already destroyed. + // As in Watches.Notify, we can't hold w.mu while acquiring Inotify.mu for + // the owner of each watch being deleted. Instead, atomically store the + // watches map in a local variable and set it to nil so we can iterate over + // it with the assurance that there will be no concurrent accesses. var ws map[uint64]*Watch w.mu.Lock() ws = w.ws w.ws = nil w.mu.Unlock() + // Remove each watch from its owner's watch set, and generate a corresponding + // watch removal event. for _, watch := range ws { - // TODO(gvisor.dev/issue/1479): consider refactoring this. - watch.handleDeletion() + i := watch.owner + i.mu.Lock() + _, found := i.watches[watch.wd] + delete(i.watches, watch.wd) + + // Release mutex before notifying waiters because we don't control what + // they can do. + i.mu.Unlock() + + // If watch was not found, it was removed from the inotify instance before + // we could get to it, in which case we should not generate an event. + if found { + i.queueEvent(newEvent(watch.wd, "", linux.IN_IGNORED, 0)) + } } } @@ -490,18 +540,28 @@ func (w *Watches) HandleDeletion() { // +stateify savable type Watch struct { // Inotify instance which owns this watch. + // + // This field is immutable after creation. owner *Inotify // Descriptor for this watch. This is unique across an inotify instance. + // + // This field is immutable after creation. wd int32 - // set is the watch set containing this watch. It belongs to the target file - // of this watch. - set *Watches + // target is a dentry representing the watch target. Its watch set contains this watch. + // + // This field is immutable after creation. + target *Dentry // Events being monitored via this watch. Must be accessed with atomic // memory operations. mask uint32 + + // expired is set to 1 to indicate that this watch is a one-shot that has + // already sent a notification and therefore can be removed. Must be accessed + // with atomic memory operations. + expired int32 } // OwnerID returns the id of the inotify instance that owns this watch. @@ -509,23 +569,29 @@ func (w *Watch) OwnerID() uint64 { return w.owner.id } -// ExcludeUnlinkedChildren indicates whether the watched object should continue -// to be notified of events of its children after they have been unlinked, e.g. -// for an open file descriptor. +// ExcludeUnlinked indicates whether the watched object should continue to be +// notified of events originating from a path that has been unlinked. // -// TODO(gvisor.dev/issue/1479): Implement IN_EXCL_UNLINK. -// We can do this by keeping track of the set of unlinked children in Watches -// to skip notification. -func (w *Watch) ExcludeUnlinkedChildren() bool { +// For example, if "foo/bar" is opened and then unlinked, operations on the +// open fd may be ignored by watches on "foo" and "foo/bar" with IN_EXCL_UNLINK. +func (w *Watch) ExcludeUnlinked() bool { return atomic.LoadUint32(&w.mask)&linux.IN_EXCL_UNLINK != 0 } -// Notify queues a new event on this watch. -func (w *Watch) Notify(name string, events uint32, cookie uint32) { +// Notify queues a new event on this watch. Returns true if this is a one-shot +// watch that should be deleted, after this event was successfully queued. +func (w *Watch) Notify(name string, events uint32, cookie uint32) bool { + if atomic.LoadInt32(&w.expired) == 1 { + // This is a one-shot watch that is already in the process of being + // removed. This may happen if a second event reaches the watch target + // before this watch has been removed. + return false + } + mask := atomic.LoadUint32(&w.mask) if mask&events == 0 { // We weren't watching for this event. - return + return false } // Event mask should include bits matched from the watch plus all control @@ -534,11 +600,11 @@ func (w *Watch) Notify(name string, events uint32, cookie uint32) { effectiveMask := unmaskableBits | mask matchedEvents := effectiveMask & events w.owner.queueEvent(newEvent(w.wd, name, matchedEvents, cookie)) -} - -// handleDeletion handles the deletion of w's target. -func (w *Watch) handleDeletion() { - w.owner.handleDeletion(w) + if mask&linux.IN_ONESHOT != 0 { + atomic.StoreInt32(&w.expired, 1) + return true + } + return false } // Event represents a struct inotify_event from linux. @@ -606,7 +672,7 @@ func (e *Event) setName(name string) { func (e *Event) sizeOf() int { s := inotifyEventBaseSize + int(e.len) if s < inotifyEventBaseSize { - panic("overflow") + panic("Overflowed event size") } return s } @@ -676,11 +742,15 @@ func InotifyEventFromStatMask(mask uint32) uint32 { } // InotifyRemoveChild sends the appriopriate notifications to the watch sets of -// the child being removed and its parent. +// the child being removed and its parent. Note that unlike most pairs of +// parent/child notifications, the child is notified first in this case. func InotifyRemoveChild(self, parent *Watches, name string) { - self.Notify("", linux.IN_ATTRIB, 0, InodeEvent) - parent.Notify(name, linux.IN_DELETE, 0, InodeEvent) - // TODO(gvisor.dev/issue/1479): implement IN_EXCL_UNLINK. + if self != nil { + self.Notify("", linux.IN_ATTRIB, 0, InodeEvent, true /* unlinked */) + } + if parent != nil { + parent.Notify(name, linux.IN_DELETE, 0, InodeEvent, true /* unlinked */) + } } // InotifyRename sends the appriopriate notifications to the watch sets of the @@ -691,8 +761,14 @@ func InotifyRename(ctx context.Context, renamed, oldParent, newParent *Watches, dirEv = linux.IN_ISDIR } cookie := uniqueid.InotifyCookie(ctx) - oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent) - newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent) + if oldParent != nil { + oldParent.Notify(oldName, dirEv|linux.IN_MOVED_FROM, cookie, InodeEvent, false /* unlinked */) + } + if newParent != nil { + newParent.Notify(newName, dirEv|linux.IN_MOVED_TO, cookie, InodeEvent, false /* unlinked */) + } // Somewhat surprisingly, self move events do not have a cookie. - renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent) + if renamed != nil { + renamed.Notify("", linux.IN_MOVE_SELF, 0, InodeEvent, false /* unlinked */) + } } diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index f223aeda8..dfc8573fd 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -79,6 +79,17 @@ type MountFlags struct { // NoATime is equivalent to MS_NOATIME and indicates that the // filesystem should not update access time in-place. NoATime bool + + // NoDev is equivalent to MS_NODEV and indicates that the + // filesystem should not allow access to devices (special files). + // TODO(gVisor.dev/issue/3186): respect this flag in non FUSE + // filesystems. + NoDev bool + + // NoSUID is equivalent to MS_NOSUID and indicates that the + // filesystem should not honor set-user-ID and set-group-ID bits or + // file capabilities when executing programs. + NoSUID bool } // MountOptions contains options to VirtualFilesystem.MountAt(). @@ -153,6 +164,12 @@ type SetStatOptions struct { // == UTIME_OMIT (VFS users must unset the corresponding bit in Stat.Mask // instead). Stat linux.Statx + + // NeedWritePerm indicates that write permission on the file is needed for + // this operation. This is needed for truncate(2) (note that ftruncate(2) + // does not require the same check--instead, it checks that the fd is + // writable). + NeedWritePerm bool } // BoundEndpointOptions contains options to VirtualFilesystem.BoundEndpointAt() diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index f9647f90e..33389c1df 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -94,6 +94,37 @@ func GenericCheckPermissions(creds *auth.Credentials, ats AccessTypes, mode linu return syserror.EACCES } +// MayLink determines whether creating a hard link to a file with the given +// mode, kuid, and kgid is permitted. +// +// This corresponds to Linux's fs/namei.c:may_linkat. +func MayLink(creds *auth.Credentials, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + // Source inode owner can hardlink all they like; otherwise, it must be a + // safe source. + if CanActAsOwner(creds, kuid) { + return nil + } + + // Only regular files can be hard linked. + if mode.FileType() != linux.S_IFREG { + return syserror.EPERM + } + + // Setuid files should not get pinned to the filesystem. + if mode&linux.S_ISUID != 0 { + return syserror.EPERM + } + + // Executable setgid files should not get pinned to the filesystem, but we + // don't support S_IXGRP anyway. + + // Hardlinking to unreadable or unwritable sources is dangerous. + if err := GenericCheckPermissions(creds, MayRead|MayWrite, mode, kuid, kgid); err != nil { + return syserror.EPERM + } + return nil +} + // AccessTypesForOpenFlags returns the access types required to open a file // with the given OpenOptions.Flags. Note that this is NOT the same thing as // the set of accesses permitted for the opened file: @@ -152,7 +183,8 @@ func MayWriteFileWithOpenFlags(flags uint32) bool { // CheckSetStat checks that creds has permission to change the metadata of a // file with the given permissions, UID, and GID as specified by stat, subject // to the rules of Linux's fs/attr.c:setattr_prepare(). -func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Statx, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { +func CheckSetStat(ctx context.Context, creds *auth.Credentials, opts *SetStatOptions, mode linux.FileMode, kuid auth.KUID, kgid auth.KGID) error { + stat := &opts.Stat if stat.Mask&linux.STATX_SIZE != 0 { limit, err := CheckLimit(ctx, 0, int64(stat.Size)) if err != nil { @@ -184,6 +216,11 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return syserror.EPERM } } + if opts.NeedWritePerm && !creds.HasCapability(linux.CAP_DAC_OVERRIDE) { + if err := GenericCheckPermissions(creds, MayWrite, mode, kuid, kgid); err != nil { + return err + } + } if stat.Mask&(linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_CTIME) != 0 { if !CanActAsOwner(creds, kuid) { if (stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec != linux.UTIME_NOW) || @@ -199,6 +236,20 @@ func CheckSetStat(ctx context.Context, creds *auth.Credentials, stat *linux.Stat return nil } +// CheckDeleteSticky checks whether the sticky bit is set on a directory with +// the given file mode, and if so, checks whether creds has permission to +// remove a file owned by childKUID from a directory with the given mode. +// CheckDeleteSticky is consistent with fs/linux.h:check_sticky(). +func CheckDeleteSticky(creds *auth.Credentials, parentMode linux.FileMode, childKUID auth.KUID) error { + if parentMode&linux.ModeSticky == 0 { + return nil + } + if CanActAsOwner(creds, childKUID) { + return nil + } + return syserror.EPERM +} + // CanActAsOwner returns true if creds can act as the owner of a file with the // given owning UID, consistent with Linux's // fs/inode.c:inode_owner_or_capable(). diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 9acca8bc7..522e27475 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -24,6 +24,9 @@ // Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry // VirtualFilesystem.filesystemsMu // EpollInstance.mu +// Inotify.mu +// Watches.mu +// Inotify.evMu // VirtualFilesystem.fsTypesMu // // Locking Dentry.mu in multiple Dentries requires holding @@ -120,6 +123,9 @@ type VirtualFilesystem struct { // Init initializes a new VirtualFilesystem with no mounts or FilesystemTypes. func (vfs *VirtualFilesystem) Init() error { + if vfs.mountpoints != nil { + panic("VFS already initialized") + } vfs.mountpoints = make(map[*Dentry]map[*Mount]struct{}) vfs.devices = make(map[devTuple]*registeredDevice) vfs.anonBlockDevMinorNext = 1 diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD new file mode 100644 index 000000000..f08599ebd --- /dev/null +++ b/pkg/shim/runsc/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "runsc", + srcs = [ + "runsc.go", + "utils.go", + ], + visibility = ["//:sandbox"], + deps = [ + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go new file mode 100644 index 000000000..c5cf68efa --- /dev/null +++ b/pkg/shim/runsc/runsc.go @@ -0,0 +1,514 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "strconv" + "syscall" + "time" + + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +var Monitor runc.ProcessMonitor = runc.Monitor + +// DefaultCommand is the default command for Runsc. +const DefaultCommand = "runsc" + +// Runsc is the client to the runsc cli. +type Runsc struct { + Command string + PdeathSignal syscall.Signal + Setpgid bool + Root string + Log string + LogFormat runc.Format + Config map[string]string +} + +// List returns all containers created inside the provided runsc root directory. +func (r *Runsc) List(context context.Context) ([]*runc.Container, error) { + data, err := cmdOutput(r.command(context, "list", "--format=json"), false) + if err != nil { + return nil, err + } + var out []*runc.Container + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return out, nil +} + +// State returns the state for the container provided by id. +func (r *Runsc) State(context context.Context, id string) (*runc.Container, error) { + data, err := cmdOutput(r.command(context, "state", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var c runc.Container + if err := json.Unmarshal(data, &c); err != nil { + return nil, err + } + return &c, nil +} + +type CreateOpts struct { + runc.IO + ConsoleSocket runc.ConsoleSocket + + // PidFile is a path to where a pid file should be created. + PidFile string + + // UserLog is a path to where runsc user log should be generated. + UserLog string +} + +func (o *CreateOpts) args() (out []string, err error) { + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.UserLog != "" { + out = append(out, "--user-log", o.UserLog) + } + return out, nil +} + +// Create creates a new container and returns its pid if it was created successfully. +func (r *Runsc) Create(context context.Context, id, bundle string, opts *CreateOpts) error { + args := []string{"create", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +// Start will start an already created container. +func (r *Runsc) Start(context context.Context, id string, cio runc.IO) error { + cmd := r.command(context, "start", id) + if cio != nil { + cio.Set(cmd) + } + + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if cio != nil { + if c, ok := cio.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return err +} + +type waitResult struct { + ID string `json:"id"` + ExitStatus int `json:"exitStatus"` +} + +// Wait will wait for a running container, and return its exit status. +// +// TODO(random-liu): Add exec process support. +func (r *Runsc) Wait(context context.Context, id string) (int, error) { + data, err := cmdOutput(r.command(context, "wait", id), true) + if err != nil { + return 0, fmt.Errorf("%s: %s", err, data) + } + var res waitResult + if err := json.Unmarshal(data, &res); err != nil { + return 0, err + } + return res.ExitStatus, nil +} + +type ExecOpts struct { + runc.IO + PidFile string + InternalPidFile string + ConsoleSocket runc.ConsoleSocket + Detach bool +} + +func (o *ExecOpts) args() (out []string, err error) { + if o.ConsoleSocket != nil { + out = append(out, "--console-socket", o.ConsoleSocket.Path()) + } + if o.Detach { + out = append(out, "--detach") + } + if o.PidFile != "" { + abs, err := filepath.Abs(o.PidFile) + if err != nil { + return nil, err + } + out = append(out, "--pid-file", abs) + } + if o.InternalPidFile != "" { + abs, err := filepath.Abs(o.InternalPidFile) + if err != nil { + return nil, err + } + out = append(out, "--internal-pid-file", abs) + } + return out, nil +} + +// Exec executes an additional process inside the container based on a full OCI +// Process specification. +func (r *Runsc) Exec(context context.Context, id string, spec specs.Process, opts *ExecOpts) error { + f, err := ioutil.TempFile(os.Getenv("XDG_RUNTIME_DIR"), "runsc-process") + if err != nil { + return err + } + defer os.Remove(f.Name()) + err = json.NewEncoder(f).Encode(spec) + f.Close() + if err != nil { + return err + } + args := []string{"exec", "--process", f.Name()} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + if cmd.Stdout == nil && cmd.Stderr == nil { + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil + } + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + if opts != nil && opts.IO != nil { + if c, ok := opts.IO.(runc.StartCloser); ok { + if err := c.CloseAfterStart(); err != nil { + return err + } + } + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err +} + +// Run runs the create, start, delete lifecycle of the container and returns +// its exit status after it has exited. +func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts) (int, error) { + args := []string{"run", "--bundle", bundle} + if opts != nil { + oargs, err := opts.args() + if err != nil { + return -1, err + } + args = append(args, oargs...) + } + cmd := r.command(context, append(args, id)...) + if opts != nil && opts.IO != nil { + opts.Set(cmd) + } + ec, err := Monitor.Start(cmd) + if err != nil { + return -1, err + } + return Monitor.Wait(cmd, ec) +} + +type DeleteOpts struct { + Force bool +} + +func (o *DeleteOpts) args() (out []string) { + if o.Force { + out = append(out, "--force") + } + return out +} + +// Delete deletes the container. +func (r *Runsc) Delete(context context.Context, id string, opts *DeleteOpts) error { + args := []string{"delete"} + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id)...)) +} + +// KillOpts specifies options for killing a container and its processes. +type KillOpts struct { + All bool + Pid int +} + +func (o *KillOpts) args() (out []string) { + if o.All { + out = append(out, "--all") + } + if o.Pid != 0 { + out = append(out, "--pid", strconv.Itoa(o.Pid)) + } + return out +} + +// Kill sends the specified signal to the container. +func (r *Runsc) Kill(context context.Context, id string, sig int, opts *KillOpts) error { + args := []string{ + "kill", + } + if opts != nil { + args = append(args, opts.args()...) + } + return r.runOrError(r.command(context, append(args, id, strconv.Itoa(sig))...)) +} + +// Stats return the stats for a container like cpu, memory, and I/O. +func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { + cmd := r.command(context, "events", "--stats", id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + defer func() { + rd.Close() + Monitor.Wait(cmd, ec) + }() + var e runc.Event + if err := json.NewDecoder(rd).Decode(&e); err != nil { + return nil, err + } + return e.Stats, nil +} + +// Events returns an event stream from runsc for a container with stats and OOM notifications. +func (r *Runsc) Events(context context.Context, id string, interval time.Duration) (chan *runc.Event, error) { + cmd := r.command(context, "events", fmt.Sprintf("--interval=%ds", int(interval.Seconds())), id) + rd, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + ec, err := Monitor.Start(cmd) + if err != nil { + rd.Close() + return nil, err + } + var ( + dec = json.NewDecoder(rd) + c = make(chan *runc.Event, 128) + ) + go func() { + defer func() { + close(c) + rd.Close() + Monitor.Wait(cmd, ec) + }() + for { + var e runc.Event + if err := dec.Decode(&e); err != nil { + if err == io.EOF { + return + } + e = runc.Event{ + Type: "error", + Err: err, + } + } + c <- &e + } + }() + return c, nil +} + +// Ps lists all the processes inside the container returning their pids. +func (r *Runsc) Ps(context context.Context, id string) ([]int, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "json", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + var pids []int + if err := json.Unmarshal(data, &pids); err != nil { + return nil, err + } + return pids, nil +} + +// Top lists all the processes inside the container returning the full ps data. +func (r *Runsc) Top(context context.Context, id string) (*runc.TopResults, error) { + data, err := cmdOutput(r.command(context, "ps", "--format", "table", id), true) + if err != nil { + return nil, fmt.Errorf("%s: %s", err, data) + } + + topResults, err := runc.ParsePSOutput(data) + if err != nil { + return nil, fmt.Errorf("%s: ", err) + } + return topResults, nil +} + +func (r *Runsc) args() []string { + var args []string + if r.Root != "" { + args = append(args, fmt.Sprintf("--root=%s", r.Root)) + } + if r.Log != "" { + args = append(args, fmt.Sprintf("--log=%s", r.Log)) + } + if r.LogFormat != "" { + args = append(args, fmt.Sprintf("--log-format=%s", r.LogFormat)) + } + for k, v := range r.Config { + args = append(args, fmt.Sprintf("--%s=%s", k, v)) + } + return args +} + +// runOrError will run the provided command. +// +// If an error is encountered and neither Stdout or Stderr was set the error +// will be returned in the format of <error>: <stderr>. +func (r *Runsc) runOrError(cmd *exec.Cmd) error { + if cmd.Stdout != nil || cmd.Stderr != nil { + ec, err := Monitor.Start(cmd) + if err != nil { + return err + } + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + return err + } + data, err := cmdOutput(cmd, true) + if err != nil { + return fmt.Errorf("%s: %s", err, data) + } + return nil +} + +func (r *Runsc) command(context context.Context, args ...string) *exec.Cmd { + command := r.Command + if command == "" { + command = DefaultCommand + } + cmd := exec.CommandContext(context, command, append(r.args(), args...)...) + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: r.Setpgid, + } + if r.PdeathSignal != 0 { + cmd.SysProcAttr.Pdeathsig = r.PdeathSignal + } + + return cmd +} + +func cmdOutput(cmd *exec.Cmd, combined bool) ([]byte, error) { + b := getBuf() + defer putBuf(b) + + cmd.Stdout = b + if combined { + cmd.Stderr = b + } + ec, err := Monitor.Start(cmd) + if err != nil { + return nil, err + } + + status, err := Monitor.Wait(cmd, ec) + if err == nil && status != 0 { + err = fmt.Errorf("%s did not terminate sucessfully", cmd.Args[0]) + } + + return b.Bytes(), err +} diff --git a/pkg/shim/runsc/utils.go b/pkg/shim/runsc/utils.go new file mode 100644 index 000000000..c514b3bc7 --- /dev/null +++ b/pkg/shim/runsc/utils.go @@ -0,0 +1,44 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runsc + +import ( + "bytes" + "strings" + "sync" +) + +var bytesBufferPool = sync.Pool{ + New: func() interface{} { + return bytes.NewBuffer(nil) + }, +} + +func getBuf() *bytes.Buffer { + return bytesBufferPool.Get().(*bytes.Buffer) +} + +func putBuf(b *bytes.Buffer) { + b.Reset() + bytesBufferPool.Put(b) +} + +// FormatLogPath parses runsc config, and fill in %ID% in the log path. +func FormatLogPath(id string, config map[string]string) { + if path, ok := config["debug-log"]; ok { + config["debug-log"] = strings.Replace(path, "%ID%", id, -1) + } +} diff --git a/pkg/shim/v1/proc/BUILD b/pkg/shim/v1/proc/BUILD new file mode 100644 index 000000000..4377306af --- /dev/null +++ b/pkg/shim/v1/proc/BUILD @@ -0,0 +1,36 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "proc", + srcs = [ + "deleted_state.go", + "exec.go", + "exec_state.go", + "init.go", + "init_state.go", + "io.go", + "process.go", + "types.go", + "utils.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_go_runc//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v1/proc/deleted_state.go b/pkg/shim/v1/proc/deleted_state.go new file mode 100644 index 000000000..d9b970c4d --- /dev/null +++ b/pkg/shim/v1/proc/deleted_state.go @@ -0,0 +1,49 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type deletedState struct{} + +func (*deletedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a deleted process.ss") +} + +func (*deletedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a deleted process.ss") +} + +func (*deletedState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) Kill(ctx context.Context, sig uint32, all bool) error { + return fmt.Errorf("cannot kill a deleted process.ss: %w", errdefs.ErrNotFound) +} + +func (*deletedState) SetExited(status int) {} + +func (*deletedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a deleted state") +} diff --git a/pkg/shim/v1/proc/exec.go b/pkg/shim/v1/proc/exec.go new file mode 100644 index 000000000..1d1d90488 --- /dev/null +++ b/pkg/shim/v1/proc/exec.go @@ -0,0 +1,281 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +type execProcess struct { + wg sync.WaitGroup + + execState execState + + mu sync.Mutex + id string + console console.Console + io runc.IO + status int + exited time.Time + pid int + internalPid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + path string + spec specs.Process + + parent *Init + waitBlock chan struct{} +} + +func (e *execProcess) Wait() { + <-e.waitBlock +} + +func (e *execProcess) ID() string { + return e.id +} + +func (e *execProcess) Pid() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.pid +} + +func (e *execProcess) ExitStatus() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.status +} + +func (e *execProcess) ExitedAt() time.Time { + e.mu.Lock() + defer e.mu.Unlock() + return e.exited +} + +func (e *execProcess) SetExited(status int) { + e.mu.Lock() + defer e.mu.Unlock() + + e.execState.SetExited(status) +} + +func (e *execProcess) setExited(status int) { + e.status = status + e.exited = time.Now() + e.parent.Platform.ShutdownConsole(context.Background(), e.console) + close(e.waitBlock) +} + +func (e *execProcess) Delete(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Delete(ctx) +} + +func (e *execProcess) delete(ctx context.Context) error { + e.wg.Wait() + if e.io != nil { + for _, c := range e.closers { + c.Close() + } + e.io.Close() + } + pidfile := filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + // silently ignore error + os.Remove(pidfile) + internalPidfile := filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + // silently ignore error + os.Remove(internalPidfile) + return nil +} + +func (e *execProcess) Resize(ws console.WinSize) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Resize(ws) +} + +func (e *execProcess) resize(ws console.WinSize) error { + if e.console == nil { + return nil + } + return e.console.Resize(ws) +} + +func (e *execProcess) Kill(ctx context.Context, sig uint32, _ bool) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Kill(ctx, sig, false) +} + +func (e *execProcess) kill(ctx context.Context, sig uint32, _ bool) error { + internalPid := e.internalPid + if internalPid != 0 { + if err := e.parent.runtime.Kill(ctx, e.parent.id, int(sig), &runsc.KillOpts{ + Pid: internalPid, + }); err != nil { + // If this returns error, consider the process has + // already stopped. + // + // TODO: Fix after signal handling is fixed. + return fmt.Errorf("%s: %w", err.Error(), errdefs.ErrNotFound) + } + } + return nil +} + +func (e *execProcess) Stdin() io.Closer { + return e.stdin +} + +func (e *execProcess) Stdio() stdio.Stdio { + return e.stdio +} + +func (e *execProcess) Start(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + return e.execState.Start(ctx) +} + +func (e *execProcess) start(ctx context.Context) (err error) { + var ( + socket *runc.Socket + pidfile = filepath.Join(e.path, fmt.Sprintf("%s.pid", e.id)) + internalPidfile = filepath.Join(e.path, fmt.Sprintf("%s-internal.pid", e.id)) + ) + if e.stdio.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create runc console socket: %w", err) + } + defer socket.Close() + } else if e.stdio.IsNull() { + if e.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if e.io, err = runc.NewPipeIO(e.parent.IoUID, e.parent.IoGID, withConditionalIO(e.stdio)); err != nil { + return fmt.Errorf("failed to create runc io pipes: %w", err) + } + } + opts := &runsc.ExecOpts{ + PidFile: pidfile, + InternalPidFile: internalPidfile, + IO: e.io, + Detach: true, + } + if socket != nil { + opts.ConsoleSocket = socket + } + eventCh := e.parent.Monitor.Subscribe() + defer func() { + // Unsubscribe if an error is returned. + if err != nil { + e.parent.Monitor.Unsubscribe(eventCh) + } + }() + if err := e.parent.runtime.Exec(ctx, e.parent.id, e.spec, opts); err != nil { + close(e.waitBlock) + return e.parent.runtimeError(err, "OCI runtime exec failed") + } + if e.stdio.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), e.stdio.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", e.stdio.Stdin, err) + } + e.closers = append(e.closers, sc) + e.stdin = sc + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + if e.console, err = e.parent.Platform.CopyConsole(ctx, console, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + } else if !e.stdio.IsNull() { + if err := copyPipes(ctx, e.io, e.stdio.Stdin, e.stdio.Stdout, e.stdio.Stderr, &e.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(opts.PidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec pid: %w", err) + } + e.pid = pid + internalPid, err := runc.ReadPidFile(opts.InternalPidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime exec internal pid: %w", err) + } + e.internalPid = internalPid + go func() { + defer e.parent.Monitor.Unsubscribe(eventCh) + for event := range eventCh { + if event.Pid == e.pid { + ExitCh <- Exit{ + Timestamp: event.Timestamp, + ID: e.id, + Status: event.Status, + } + break + } + } + }() + return nil +} + +func (e *execProcess) Status(ctx context.Context) (string, error) { + e.mu.Lock() + defer e.mu.Unlock() + // if we don't have a pid then the exec process has just been created + if e.pid == 0 { + return "created", nil + } + // if we have a pid and it can be signaled, the process is running + // TODO(random-liu): Use `runsc kill --pid`. + if err := unix.Kill(e.pid, 0); err == nil { + return "running", nil + } + // else if we have a pid but it can nolonger be signaled, it has stopped + return "stopped", nil +} diff --git a/pkg/shim/v1/proc/exec_state.go b/pkg/shim/v1/proc/exec_state.go new file mode 100644 index 000000000..4dcda8b44 --- /dev/null +++ b/pkg/shim/v1/proc/exec_state.go @@ -0,0 +1,154 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" +) + +type execState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type execCreatedState struct { + p *execProcess +} + +func (s *execCreatedState) transition(name string) error { + switch name { + case "running": + s.p.execState = &execRunningState{p: s.p} + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execCreatedState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execCreatedState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + return err + } + return s.transition("running") +} + +func (s *execCreatedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execCreatedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execCreatedState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execRunningState struct { + p *execProcess +} + +func (s *execRunningState) transition(name string) error { + switch name { + case "stopped": + s.p.execState = &execStoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execRunningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *execRunningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process") +} + +func (s *execRunningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process") +} + +func (s *execRunningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execRunningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +type execStoppedState struct { + p *execProcess +} + +func (s *execStoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.execState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *execStoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *execStoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process") +} + +func (s *execStoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *execStoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *execStoppedState) SetExited(status int) { + // no op +} diff --git a/pkg/shim/v1/proc/init.go b/pkg/shim/v1/proc/init.go new file mode 100644 index 000000000..dab3123d6 --- /dev/null +++ b/pkg/shim/v1/proc/init.go @@ -0,0 +1,460 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "path/filepath" + "strings" + "sync" + "syscall" + "time" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" + specs "github.com/opencontainers/runtime-spec/specs-go" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +// InitPidFile name of the file that contains the init pid. +const InitPidFile = "init.pid" + +// Init represents an initial process for a container. +type Init struct { + wg sync.WaitGroup + initState initState + + // mu is used to ensure that `Start()` and `Exited()` calls return in + // the right order when invoked in separate go routines. This is the + // case within the shim implementation as it makes use of the reaper + // interface. + mu sync.Mutex + + waitBlock chan struct{} + + WorkDir string + + id string + Bundle string + console console.Console + Platform stdio.Platform + io runc.IO + runtime *runsc.Runsc + status int + exited time.Time + pid int + closers []io.Closer + stdin io.Closer + stdio stdio.Stdio + Rootfs string + IoUID int + IoGID int + Sandbox bool + UserLog string + Monitor ProcessMonitor +} + +// NewRunsc returns a new runsc instance for a process. +func NewRunsc(root, path, namespace, runtime string, config map[string]string) *runsc.Runsc { + if root == "" { + root = RunscRoot + } + return &runsc.Runsc{ + Command: runtime, + PdeathSignal: syscall.SIGKILL, + Log: filepath.Join(path, "log.json"), + LogFormat: runc.JSON, + Root: filepath.Join(root, namespace), + Config: config, + } +} + +// New returns a new init process. +func New(id string, runtime *runsc.Runsc, stdio stdio.Stdio) *Init { + p := &Init{ + id: id, + runtime: runtime, + stdio: stdio, + status: 0, + waitBlock: make(chan struct{}), + } + p.initState = &createdState{p: p} + return p +} + +// Create the process with the provided config. +func (p *Init) Create(ctx context.Context, r *CreateConfig) (err error) { + var socket *runc.Socket + if r.Terminal { + if socket, err = runc.NewTempConsoleSocket(); err != nil { + return fmt.Errorf("failed to create OCI runtime console socket: %w", err) + } + defer socket.Close() + } else if hasNoIO(r) { + if p.io, err = runc.NewNullIO(); err != nil { + return fmt.Errorf("creating new NULL IO: %w", err) + } + } else { + if p.io, err = runc.NewPipeIO(p.IoUID, p.IoGID, withConditionalIO(p.stdio)); err != nil { + return fmt.Errorf("failed to create OCI runtime io pipes: %w", err) + } + } + pidFile := filepath.Join(p.Bundle, InitPidFile) + opts := &runsc.CreateOpts{ + PidFile: pidFile, + } + if socket != nil { + opts.ConsoleSocket = socket + } + if p.Sandbox { + opts.IO = p.io + // UserLog is only useful for sandbox. + opts.UserLog = p.UserLog + } + if err := p.runtime.Create(ctx, r.ID, r.Bundle, opts); err != nil { + return p.runtimeError(err, "OCI runtime create failed") + } + if r.Stdin != "" { + sc, err := fifo.OpenFifo(context.Background(), r.Stdin, syscall.O_WRONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("failed to open stdin fifo %s: %w", r.Stdin, err) + } + p.stdin = sc + p.closers = append(p.closers, sc) + } + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + if socket != nil { + console, err := socket.ReceiveMaster() + if err != nil { + return fmt.Errorf("failed to retrieve console master: %w", err) + } + console, err = p.Platform.CopyConsole(ctx, console, r.Stdin, r.Stdout, r.Stderr, &p.wg) + if err != nil { + return fmt.Errorf("failed to start console copy: %w", err) + } + p.console = console + } else if !hasNoIO(r) { + if err := copyPipes(ctx, p.io, r.Stdin, r.Stdout, r.Stderr, &p.wg); err != nil { + return fmt.Errorf("failed to start io pipe copy: %w", err) + } + } + pid, err := runc.ReadPidFile(pidFile) + if err != nil { + return fmt.Errorf("failed to retrieve OCI runtime container pid: %w", err) + } + p.pid = pid + return nil +} + +// Wait waits for the process to exit. +func (p *Init) Wait() { + <-p.waitBlock +} + +// ID returns the ID of the process. +func (p *Init) ID() string { + return p.id +} + +// Pid returns the PID of the process. +func (p *Init) Pid() int { + return p.pid +} + +// ExitStatus returns the exit status of the process. +func (p *Init) ExitStatus() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.status +} + +// ExitedAt returns the time when the process exited. +func (p *Init) ExitedAt() time.Time { + p.mu.Lock() + defer p.mu.Unlock() + return p.exited +} + +// Status returns the status of the process. +func (p *Init) Status(ctx context.Context) (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + c, err := p.runtime.State(ctx, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return "stopped", nil + } + return "", p.runtimeError(err, "OCI runtime state failed") + } + return p.convertStatus(c.Status), nil +} + +// Start starts the init process. +func (p *Init) Start(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Start(ctx) +} + +func (p *Init) start(ctx context.Context) error { + var cio runc.IO + if !p.Sandbox { + cio = p.io + } + if err := p.runtime.Start(ctx, p.id, cio); err != nil { + return p.runtimeError(err, "OCI runtime start failed") + } + go func() { + status, err := p.runtime.Wait(context.Background(), p.id) + if err != nil { + log.G(ctx).WithError(err).Errorf("Failed to wait for container %q", p.id) + // TODO(random-liu): Handle runsc kill error. + if err := p.killAll(ctx); err != nil { + log.G(ctx).WithError(err).Errorf("Failed to kill container %q", p.id) + } + status = internalErrorCode + } + ExitCh <- Exit{ + Timestamp: time.Now(), + ID: p.id, + Status: status, + } + }() + return nil +} + +// SetExited set the exit stauts of the init process. +func (p *Init) SetExited(status int) { + p.mu.Lock() + defer p.mu.Unlock() + + p.initState.SetExited(status) +} + +func (p *Init) setExited(status int) { + p.exited = time.Now() + p.status = status + p.Platform.ShutdownConsole(context.Background(), p.console) + close(p.waitBlock) +} + +// Delete deletes the init process. +func (p *Init) Delete(ctx context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Delete(ctx) +} + +func (p *Init) delete(ctx context.Context) error { + p.killAll(ctx) + p.wg.Wait() + err := p.runtime.Delete(ctx, p.id, nil) + // ignore errors if a runtime has already deleted the process + // but we still hold metadata and pipes + // + // this is common during a checkpoint, runc will delete the container state + // after a checkpoint and the container will no longer exist within runc + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + err = nil + } else { + err = p.runtimeError(err, "failed to delete task") + } + } + if p.io != nil { + for _, c := range p.closers { + c.Close() + } + p.io.Close() + } + if err2 := mount.UnmountAll(p.Rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("failed to cleanup rootfs mount") + if err == nil { + err = fmt.Errorf("failed rootfs umount: %w", err2) + } + } + return err +} + +// Resize resizes the init processes console. +func (p *Init) Resize(ws console.WinSize) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +func (p *Init) resize(ws console.WinSize) error { + if p.console == nil { + return nil + } + return p.console.Resize(ws) +} + +// Kill kills the init process. +func (p *Init) Kill(ctx context.Context, signal uint32, all bool) error { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Kill(ctx, signal, all) +} + +func (p *Init) kill(context context.Context, signal uint32, all bool) error { + var ( + killErr error + backoff = 100 * time.Millisecond + ) + timeout := 1 * time.Second + for start := time.Now(); time.Now().Sub(start) < timeout; { + c, err := p.runtime.State(context, p.id) + if err != nil { + if strings.Contains(err.Error(), "does not exist") { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + return p.runtimeError(err, "OCI runtime state failed") + } + // For runsc, signal only works when container is running state. + // If the container is not in running state, directly return + // "no such process" + if p.convertStatus(c.Status) == "stopped" { + return fmt.Errorf("no such process: %w", errdefs.ErrNotFound) + } + killErr = p.runtime.Kill(context, p.id, int(signal), &runsc.KillOpts{ + All: all, + }) + if killErr == nil { + return nil + } + time.Sleep(backoff) + backoff *= 2 + } + return p.runtimeError(killErr, "kill timeout") +} + +// KillAll kills all processes belonging to the init process. +func (p *Init) KillAll(context context.Context) error { + p.mu.Lock() + defer p.mu.Unlock() + return p.killAll(context) +} + +func (p *Init) killAll(context context.Context) error { + p.runtime.Kill(context, p.id, int(syscall.SIGKILL), &runsc.KillOpts{ + All: true, + }) + // Ignore error handling for `runsc kill --all` for now. + // * If it doesn't return error, it is good; + // * If it returns error, consider the container has already stopped. + // TODO: Fix `runsc kill --all` error handling. + return nil +} + +// Stdin returns the stdin of the process. +func (p *Init) Stdin() io.Closer { + return p.stdin +} + +// Runtime returns the OCI runtime configured for the init process. +func (p *Init) Runtime() *runsc.Runsc { + return p.runtime +} + +// Exec returns a new child process. +func (p *Init) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + p.mu.Lock() + defer p.mu.Unlock() + + return p.initState.Exec(ctx, path, r) +} + +// exec returns a new exec'd process. +func (p *Init) exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + // process exec request + var spec specs.Process + if err := json.Unmarshal(r.Spec.Value, &spec); err != nil { + return nil, err + } + spec.Terminal = r.Terminal + + e := &execProcess{ + id: r.ID, + path: path, + parent: p, + spec: spec, + stdio: stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }, + waitBlock: make(chan struct{}), + } + e.execState = &execCreatedState{p: e} + return e, nil +} + +// Stdio returns the stdio of the process. +func (p *Init) Stdio() stdio.Stdio { + return p.stdio +} + +func (p *Init) runtimeError(rErr error, msg string) error { + if rErr == nil { + return nil + } + + rMsg, err := getLastRuntimeError(p.runtime) + switch { + case err != nil: + return fmt.Errorf("%s: %w (unable to retrieve OCI runtime error: %v)", msg, rErr, err) + case rMsg == "": + return fmt.Errorf("%s: %w", msg, rErr) + default: + return fmt.Errorf("%s: %s", msg, rMsg) + } +} + +func (p *Init) convertStatus(status string) string { + if status == "created" && !p.Sandbox && p.status == internalErrorCode { + // Treat start failure state for non-root container as stopped. + return "stopped" + } + return status +} + +func withConditionalIO(c stdio.Stdio) runc.IOOpt { + return func(o *runc.IOOption) { + o.OpenStdin = c.Stdin != "" + o.OpenStdout = c.Stdout != "" + o.OpenStderr = c.Stderr != "" + } +} diff --git a/pkg/shim/v1/proc/init_state.go b/pkg/shim/v1/proc/init_state.go new file mode 100644 index 000000000..9233ecc85 --- /dev/null +++ b/pkg/shim/v1/proc/init_state.go @@ -0,0 +1,182 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + + "github.com/containerd/console" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/pkg/process" +) + +type initState interface { + Resize(console.WinSize) error + Start(context.Context) error + Delete(context.Context) error + Exec(context.Context, string, *ExecConfig) (process.Process, error) + Kill(context.Context, uint32, bool) error + SetExited(int) +} + +type createdState struct { + p *Init +} + +func (s *createdState) transition(name string) error { + switch name { + case "running": + s.p.initState = &runningState{p: s.p} + case "stopped": + s.p.initState = &stoppedState{p: s.p} + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *createdState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *createdState) Start(ctx context.Context) error { + if err := s.p.start(ctx); err != nil { + // Containerd doesn't allow deleting container in created state. + // However, for gvisor, a non-root container in created state can + // only go to running state. If the container can't be started, + // it can only stay in created state, and never be deleted. + // To work around that, we treat non-root container in start failure + // state as stopped. + if !s.p.Sandbox { + s.p.io.Close() + s.p.setExited(internalErrorCode) + if err := s.transition("stopped"); err != nil { + panic(err) + } + } + return err + } + return s.transition("running") +} + +func (s *createdState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *createdState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *createdState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *createdState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type runningState struct { + p *Init +} + +func (s *runningState) transition(name string) error { + switch name { + case "stopped": + s.p.initState = &stoppedState{p: s.p} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *runningState) Resize(ws console.WinSize) error { + return s.p.resize(ws) +} + +func (s *runningState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a running process.ss") +} + +func (s *runningState) Delete(ctx context.Context) error { + return fmt.Errorf("cannot delete a running process.ss") +} + +func (s *runningState) Kill(ctx context.Context, sig uint32, all bool) error { + return s.p.kill(ctx, sig, all) +} + +func (s *runningState) SetExited(status int) { + s.p.setExited(status) + + if err := s.transition("stopped"); err != nil { + panic(err) + } +} + +func (s *runningState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return s.p.exec(ctx, path, r) +} + +type stoppedState struct { + p *Init +} + +func (s *stoppedState) transition(name string) error { + switch name { + case "deleted": + s.p.initState = &deletedState{} + default: + return fmt.Errorf("invalid state transition %q to %q", stateName(s), name) + } + return nil +} + +func (s *stoppedState) Resize(ws console.WinSize) error { + return fmt.Errorf("cannot resize a stopped container") +} + +func (s *stoppedState) Start(ctx context.Context) error { + return fmt.Errorf("cannot start a stopped process.ss") +} + +func (s *stoppedState) Delete(ctx context.Context) error { + if err := s.p.delete(ctx); err != nil { + return err + } + return s.transition("deleted") +} + +func (s *stoppedState) Kill(ctx context.Context, sig uint32, all bool) error { + return errdefs.ToGRPCf(errdefs.ErrNotFound, "process.ss %s not found", s.p.id) +} + +func (s *stoppedState) SetExited(status int) { + // no op +} + +func (s *stoppedState) Exec(ctx context.Context, path string, r *ExecConfig) (process.Process, error) { + return nil, fmt.Errorf("cannot exec in a stopped state") +} diff --git a/pkg/shim/v1/proc/io.go b/pkg/shim/v1/proc/io.go new file mode 100644 index 000000000..34d825fb7 --- /dev/null +++ b/pkg/shim/v1/proc/io.go @@ -0,0 +1,162 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "context" + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "syscall" + + "github.com/containerd/containerd/log" + "github.com/containerd/fifo" + runc "github.com/containerd/go-runc" +) + +// TODO(random-liu): This file can be a util. + +var bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, +} + +func copyPipes(ctx context.Context, rio runc.IO, stdin, stdout, stderr string, wg *sync.WaitGroup) error { + var sameFile *countingWriteCloser + for _, i := range []struct { + name string + dest func(wc io.WriteCloser, rc io.Closer) + }{ + { + name: stdout, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stdout(), *p); err != nil { + log.G(ctx).Warn("error copying stdout") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, { + name: stderr, + dest: func(wc io.WriteCloser, rc io.Closer) { + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + if _, err := io.CopyBuffer(wc, rio.Stderr(), *p); err != nil { + log.G(ctx).Warn("error copying stderr") + } + wg.Done() + wc.Close() + if rc != nil { + rc.Close() + } + }() + }, + }, + } { + ok, err := isFifo(i.name) + if err != nil { + return err + } + var ( + fw io.WriteCloser + fr io.Closer + ) + if ok { + if fw, err = fifo.OpenFifo(ctx, i.name, syscall.O_WRONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if fr, err = fifo.OpenFifo(ctx, i.name, syscall.O_RDONLY, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + } else { + if sameFile != nil { + sameFile.count++ + i.dest(sameFile, nil) + continue + } + if fw, err = os.OpenFile(i.name, syscall.O_WRONLY|syscall.O_APPEND, 0); err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", i.name, err) + } + if stdout == stderr { + sameFile = &countingWriteCloser{ + WriteCloser: fw, + count: 1, + } + } + } + i.dest(fw, fr) + } + if stdin == "" { + return nil + } + f, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return fmt.Errorf("gvisor-containerd-shim: opening %s failed: %s", stdin, err) + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + + io.CopyBuffer(rio.Stdin(), f, *p) + rio.Stdin().Close() + f.Close() + }() + return nil +} + +// countingWriteCloser masks io.Closer() until close has been invoked a certain number of times. +type countingWriteCloser struct { + io.WriteCloser + count int64 +} + +func (c *countingWriteCloser) Close() error { + if atomic.AddInt64(&c.count, -1) > 0 { + return nil + } + return c.WriteCloser.Close() +} + +// isFifo checks if a file is a fifo. +// +// If the file does not exist then it returns false. +func isFifo(path string) (bool, error) { + stat, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + if stat.Mode()&os.ModeNamedPipe == os.ModeNamedPipe { + return true, nil + } + return false, nil +} diff --git a/pkg/shim/v1/proc/process.go b/pkg/shim/v1/proc/process.go new file mode 100644 index 000000000..d462c3eef --- /dev/null +++ b/pkg/shim/v1/proc/process.go @@ -0,0 +1,37 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "fmt" +) + +// RunscRoot is the path to the root runsc state directory. +const RunscRoot = "/run/containerd/runsc" + +func stateName(v interface{}) string { + switch v.(type) { + case *runningState, *execRunningState: + return "running" + case *createdState, *execCreatedState: + return "created" + case *deletedState: + return "deleted" + case *stoppedState: + return "stopped" + } + panic(fmt.Errorf("invalid state %v", v)) +} diff --git a/pkg/shim/v1/proc/types.go b/pkg/shim/v1/proc/types.go new file mode 100644 index 000000000..2b0df4663 --- /dev/null +++ b/pkg/shim/v1/proc/types.go @@ -0,0 +1,69 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "time" + + runc "github.com/containerd/go-runc" + "github.com/gogo/protobuf/types" +) + +// Mount holds filesystem mount configuration. +type Mount struct { + Type string + Source string + Target string + Options []string +} + +// CreateConfig hold task creation configuration. +type CreateConfig struct { + ID string + Bundle string + Runtime string + Rootfs []Mount + Terminal bool + Stdin string + Stdout string + Stderr string + Options *types.Any +} + +// ExecConfig holds exec creation configuration. +type ExecConfig struct { + ID string + Terminal bool + Stdin string + Stdout string + Stderr string + Spec *types.Any +} + +// Exit is the type of exit events. +type Exit struct { + Timestamp time.Time + ID string + Status int +} + +// ProcessMonitor monitors process exit changes. +type ProcessMonitor interface { + // Subscribe to process exit changes + Subscribe() chan runc.Exit + // Unsubscribe to process exit changes + Unsubscribe(c chan runc.Exit) +} diff --git a/pkg/shim/v1/proc/utils.go b/pkg/shim/v1/proc/utils.go new file mode 100644 index 000000000..716de2f59 --- /dev/null +++ b/pkg/shim/v1/proc/utils.go @@ -0,0 +1,90 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proc + +import ( + "encoding/json" + "io" + "os" + "strings" + "time" + + "gvisor.dev/gvisor/pkg/shim/runsc" +) + +const ( + internalErrorCode = 128 + bufferSize = 32 +) + +// ExitCh is the exit events channel for containers and exec processes +// inside the sandbox. +var ExitCh = make(chan Exit, bufferSize) + +// TODO(mlaventure): move to runc package? +func getLastRuntimeError(r *runsc.Runsc) (string, error) { + if r.Log == "" { + return "", nil + } + + f, err := os.OpenFile(r.Log, os.O_RDONLY, 0400) + if err != nil { + return "", err + } + + var ( + errMsg string + log struct { + Level string + Msg string + Time time.Time + } + ) + + dec := json.NewDecoder(f) + for err = nil; err == nil; { + if err = dec.Decode(&log); err != nil && err != io.EOF { + return "", err + } + if log.Level == "error" { + errMsg = strings.TrimSpace(log.Msg) + } + } + + return errMsg, nil +} + +func copyFile(to, from string) error { + ff, err := os.Open(from) + if err != nil { + return err + } + defer ff.Close() + tt, err := os.Create(to) + if err != nil { + return err + } + defer tt.Close() + + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + _, err = io.CopyBuffer(tt, ff, *p) + return err +} + +func hasNoIO(r *CreateConfig) bool { + return r.Stdin == "" && r.Stdout == "" && r.Stderr == "" +} diff --git a/pkg/shim/v1/shim/BUILD b/pkg/shim/v1/shim/BUILD new file mode 100644 index 000000000..05c595bc9 --- /dev/null +++ b/pkg/shim/v1/shim/BUILD @@ -0,0 +1,40 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "shim", + srcs = [ + "api.go", + "platform.go", + "service.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", + ], +) diff --git a/pkg/shim/v1/shim/api.go b/pkg/shim/v1/shim/api.go new file mode 100644 index 000000000..5dd8ff172 --- /dev/null +++ b/pkg/shim/v1/shim/api.go @@ -0,0 +1,28 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskCreate = events.TaskCreate +type TaskStart = events.TaskStart +type TaskOOM = events.TaskOOM +type TaskExit = events.TaskExit +type TaskDelete = events.TaskDelete +type TaskExecAdded = events.TaskExecAdded +type TaskExecStarted = events.TaskExecStarted diff --git a/pkg/shim/v1/shim/platform.go b/pkg/shim/v1/shim/platform.go new file mode 100644 index 000000000..f590f80ef --- /dev/null +++ b/pkg/shim/v1/shim/platform.go @@ -0,0 +1,106 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(ctx, stdin, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *Service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/shim/v1/shim/service.go b/pkg/shim/v1/shim/service.go new file mode 100644 index 000000000..84a810cb2 --- /dev/null +++ b/pkg/shim/v1/shim/service.go @@ -0,0 +1,573 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shim + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/containerd/console" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + shim "github.com/containerd/containerd/runtime/v1/shim/v1" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +// Config contains shim specific configuration. +type Config struct { + Path string + Namespace string + WorkDir string + RuntimeRoot string + RunscConfig map[string]string +} + +// NewService returns a new shim service that can be used via GRPC. +func NewService(config Config, publisher events.Publisher) (*Service, error) { + if config.Namespace == "" { + return nil, fmt.Errorf("shim namespace cannot be empty") + } + ctx := namespaces.WithNamespace(context.Background(), config.Namespace) + s := &Service{ + config: config, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + } + go s.processExits() + if err := s.initPlatform(); err != nil { + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// Service is the shim implementation of a remote shim over GRPC. +type Service struct { + mu sync.Mutex + + config Config + context context.Context + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + ec chan proc.Exit + + // Filled by Create() + id string + bundle string +} + +// Create creates a new initial process and container with the underlying OCI runtime. +func (s *Service) Create(ctx context.Context, r *shim.CreateTaskRequest) (_ *shim.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: r.Runtime, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + defer func() { + if err != nil { + if err2 := mount.UnmountAll(rootfs, 0); err2 != nil { + log.G(ctx).WithError(err2).Warn("Failed to cleanup rootfs mount") + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + s.config.Path, + s.config.WorkDir, + s.config.RuntimeRoot, + s.config.Namespace, + s.config.RunscConfig, + s.platform, + config, + ) + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + pid := process.Pid() + s.processes[r.ID] = process + return &shim.CreateTaskResponse{ + Pid: uint32(pid), + }, nil +} + +// Start starts a process. +func (s *Service) Start(ctx context.Context, r *shim.StartRequest) (*shim.StartResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + return &shim.StartResponse{ + ID: p.ID(), + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *Service) Delete(ctx context.Context, r *types.Empty) (*shim.DeleteResponse, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, s.id) + s.mu.Unlock() + s.platform.Close() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// DeleteProcess deletes an exec'd process. +func (s *Service) DeleteProcess(ctx context.Context, r *shim.DeleteProcessRequest) (*shim.DeleteResponse, error) { + if r.ID == s.id { + return nil, status.Errorf(codes.InvalidArgument, "cannot delete init process with DeleteProcess") + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + s.mu.Lock() + delete(s.processes, r.ID) + s.mu.Unlock() + return &shim.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *Service) Exec(ctx context.Context, r *shim.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + + if p := s.processes[r.ID]; p != nil { + s.mu.Unlock() + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ID) + } + + p := s.processes[s.id] + s.mu.Unlock() + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + + process, err := p.(*proc.Init).Exec(ctx, s.config.Path, &proc.ExecConfig{ + ID: r.ID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resises the terminal of a process. +func (s *Service) ResizePty(ctx context.Context, r *shim.ResizePtyRequest) (*types.Empty, error) { + if r.ID == "" { + return nil, errdefs.ToGRPCf(errdefs.ErrInvalidArgument, "id not provided") + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *Service) State(ctx context.Context, r *shim.StateRequest) (*shim.StateResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &shim.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause pauses the container. +func (s *Service) Pause(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume resumes the container. +func (s *Service) Resume(ctx context.Context, r *types.Empty) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill kills a process with the provided signal. +func (s *Service) Kill(ctx context.Context, r *shim.KillRequest) (*types.Empty, error) { + if r.ID == "" { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil + } + + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// ListPids returns all pids inside the container. +func (s *Service) ListPids(ctx context.Context, r *shim.ListPidsRequest) (*shim.ListPidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &shim.ListPidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *Service) CloseIO(ctx context.Context, r *shim.CloseIORequest) (*types.Empty, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *Service) Checkpoint(ctx context.Context, r *shim.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// ShimInfo returns shim information such as the shim's pid. +func (s *Service) ShimInfo(ctx context.Context, r *types.Empty) (*shim.ShimInfoResponse, error) { + return &shim.ShimInfoResponse{ + ShimPid: uint32(os.Getpid()), + }, nil +} + +// Update updates a running container. +func (s *Service) Update(ctx context.Context, r *shim.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *Service) Wait(ctx context.Context, r *shim.WaitRequest) (*shim.WaitResponse, error) { + p, err := s.getExecProcess(r.ID) + if err != nil { + return nil, err + } + p.Wait() + + return &shim.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *Service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *Service) allProcesses() []process.Process { + s.mu.Lock() + defer s.mu.Unlock() + + res := make([]process.Process, 0, len(s.processes)) + for _, p := range s.processes { + res = append(res, p) + } + return res +} + +func (s *Service) checkProcesses(e proc.Exit) { + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *Service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + p, err := s.getInitProcess() + if err != nil { + return nil, err + } + + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *Service) forward(publisher events.Publisher) { + for e := range s.events { + if err := publisher.Publish(s.context, getTopic(s.context, e), e); err != nil { + log.G(s.context).WithError(err).Error("post event") + } + } +} + +// getInitProcess returns the init process. +func (s *Service) getInitProcess() (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[s.id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + return p, nil +} + +// getExecProcess returns the given exec process. +func (s *Service) getExecProcess(id string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + p := s.processes[id] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process %s does not exist", id) + } + return p, nil +} + +func getTopic(ctx context.Context, e interface{}) string { + switch e.(type) { + case *TaskCreate: + return runtime.TaskCreateEventTopic + case *TaskStart: + return runtime.TaskStartEventTopic + case *TaskOOM: + return runtime.TaskOOMEventTopic + case *TaskExit: + return runtime.TaskExitEventTopic + case *TaskDelete: + return runtime.TaskDeleteEventTopic + case *TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, runtimeRoot, namespace string, config map[string]string, platform stdio.Platform, r *proc.CreateConfig) (*proc.Init, error) { + var options runctypes.CreateOptions + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + options = *v.(*runctypes.CreateOptions) + } + + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + + runsc.FormatLogPath(r.ID, config) + rootfs := filepath.Join(path, "rootfs") + runtime := proc.NewRunsc(runtimeRoot, path, namespace, r.Runtime, config) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = utils.IsSandbox(spec) + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v1/utils/BUILD b/pkg/shim/v1/utils/BUILD new file mode 100644 index 000000000..54a0aabb7 --- /dev/null +++ b/pkg/shim/v1/utils/BUILD @@ -0,0 +1,27 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "utils", + srcs = [ + "annotations.go", + "utils.go", + "volumes.go", + ], + visibility = [ + "//pkg/shim:__subpackages__", + "//shim:__subpackages__", + ], + deps = [ + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "utils_test", + size = "small", + srcs = ["volumes_test.go"], + library = ":utils", + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], +) diff --git a/pkg/shim/v1/utils/annotations.go b/pkg/shim/v1/utils/annotations.go new file mode 100644 index 000000000..1e9d3f365 --- /dev/null +++ b/pkg/shim/v1/utils/annotations.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +// Annotations from the CRI annotations package. +// +// These are vendor due to import conflicts. +const ( + sandboxLogDirAnnotation = "io.kubernetes.cri.sandbox-log-directory" + containerTypeAnnotation = "io.kubernetes.cri.container-type" + containerTypeSandbox = "sandbox" + containerTypeContainer = "container" +) diff --git a/pkg/shim/v1/utils/utils.go b/pkg/shim/v1/utils/utils.go new file mode 100644 index 000000000..07e346654 --- /dev/null +++ b/pkg/shim/v1/utils/utils.go @@ -0,0 +1,56 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +// ReadSpec reads OCI spec from the bundle directory. +func ReadSpec(bundle string) (*specs.Spec, error) { + f, err := os.Open(filepath.Join(bundle, "config.json")) + if err != nil { + return nil, err + } + b, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + return nil, err + } + return &spec, nil +} + +// IsSandbox checks whether a container is a sandbox container. +func IsSandbox(spec *specs.Spec) bool { + t, ok := spec.Annotations[containerTypeAnnotation] + return !ok || t == containerTypeSandbox +} + +// UserLogPath gets user log path from OCI annotation. +func UserLogPath(spec *specs.Spec) string { + sandboxLogDir := spec.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "" + } + return filepath.Join(sandboxLogDir, "gvisor.log") +} diff --git a/pkg/shim/v1/utils/volumes.go b/pkg/shim/v1/utils/volumes.go new file mode 100644 index 000000000..52a428179 --- /dev/null +++ b/pkg/shim/v1/utils/volumes.go @@ -0,0 +1,155 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +const volumeKeyPrefix = "dev.gvisor.spec.mount." + +var kubeletPodsDir = "/var/lib/kubelet/pods" + +// volumeName gets volume name from volume annotation key, example: +// dev.gvisor.spec.mount.NAME.share +func volumeName(k string) string { + return strings.SplitN(strings.TrimPrefix(k, volumeKeyPrefix), ".", 2)[0] +} + +// volumeFieldName gets volume field name from volume annotation key, example: +// `type` is the field of dev.gvisor.spec.mount.NAME.type +func volumeFieldName(k string) string { + parts := strings.Split(strings.TrimPrefix(k, volumeKeyPrefix), ".") + return parts[len(parts)-1] +} + +// podUID gets pod UID from the pod log path. +func podUID(s *specs.Spec) (string, error) { + sandboxLogDir := s.Annotations[sandboxLogDirAnnotation] + if sandboxLogDir == "" { + return "", fmt.Errorf("no sandbox log path annotation") + } + fields := strings.Split(filepath.Base(sandboxLogDir), "_") + switch len(fields) { + case 1: // This is the old CRI logging path. + return fields[0], nil + case 3: // This is the new CRI logging path. + return fields[2], nil + } + return "", fmt.Errorf("unexpected sandbox log path %q", sandboxLogDir) +} + +// isVolumeKey checks whether an annotation key is for volume. +func isVolumeKey(k string) bool { + return strings.HasPrefix(k, volumeKeyPrefix) +} + +// volumeSourceKey constructs the annotation key for volume source. +func volumeSourceKey(volume string) string { + return volumeKeyPrefix + volume + ".source" +} + +// volumePath searches the volume path in the kubelet pod directory. +func volumePath(volume, uid string) (string, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/%s/volumes/*/%s", kubeletPodsDir, uid, volume) + dirs, err := filepath.Glob(volumeSearchPath) + if err != nil { + return "", err + } + if len(dirs) != 1 { + return "", fmt.Errorf("unexpected matched volume list %v", dirs) + } + return dirs[0], nil +} + +// isVolumePath checks whether a string is the volume path. +func isVolumePath(volume, path string) (bool, error) { + // TODO: Support subpath when gvisor supports pod volume bind mount. + volumeSearchPath := fmt.Sprintf("%s/*/volumes/*/%s", kubeletPodsDir, volume) + return filepath.Match(volumeSearchPath, path) +} + +// UpdateVolumeAnnotations add necessary OCI annotations for gvisor +// volume optimization. +func UpdateVolumeAnnotations(bundle string, s *specs.Spec) error { + var ( + uid string + err error + ) + if IsSandbox(s) { + uid, err = podUID(s) + if err != nil { + // Skip if we can't get pod UID, because this doesn't work + // for containerd 1.1. + return nil + } + } + var updated bool + for k, v := range s.Annotations { + if !isVolumeKey(k) { + continue + } + if volumeFieldName(k) != "type" { + continue + } + volume := volumeName(k) + if uid != "" { + // This is a sandbox. + path, err := volumePath(volume, uid) + if err != nil { + return fmt.Errorf("get volume path for %q: %w", volume, err) + } + s.Annotations[volumeSourceKey(volume)] = path + updated = true + } else { + // This is a container. + for i := range s.Mounts { + // An error is returned for sandbox if source + // annotation is not successfully applied, so + // it is guaranteed that the source annotation + // for sandbox has already been successfully + // applied at this point. + // + // The volume name is unique inside a pod, so + // matching without podUID is fine here. + // + // TODO: Pass podUID down to shim for containers to do + // more accurate matching. + if yes, _ := isVolumePath(volume, s.Mounts[i].Source); yes { + // gVisor requires the container mount type to match + // sandbox mount type. + s.Mounts[i].Type = v + updated = true + } + } + } + } + if !updated { + return nil + } + // Update bundle. + b, err := json.Marshal(s) + if err != nil { + return err + } + return ioutil.WriteFile(filepath.Join(bundle, "config.json"), b, 0666) +} diff --git a/pkg/shim/v1/utils/volumes_test.go b/pkg/shim/v1/utils/volumes_test.go new file mode 100644 index 000000000..3e02c6151 --- /dev/null +++ b/pkg/shim/v1/utils/volumes_test.go @@ -0,0 +1,308 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" +) + +func TestUpdateVolumeAnnotations(t *testing.T) { + dir, err := ioutil.TempDir("", "test-update-volume-annotations") + if err != nil { + t.Fatalf("create tempdir: %v", err) + } + defer os.RemoveAll(dir) + kubeletPodsDir = dir + + const ( + testPodUID = "testuid" + testVolumeName = "testvolume" + testLogDirPath = "/var/log/pods/testns_testname_" + testPodUID + testLegacyLogDirPath = "/var/log/pods/" + testPodUID + ) + testVolumePath := fmt.Sprintf("%s/%s/volumes/kubernetes.io~empty-dir/%s", dir, testPodUID, testVolumeName) + + if err := os.MkdirAll(testVolumePath, 0755); err != nil { + t.Fatalf("Create test volume: %v", err) + } + + for _, test := range []struct { + desc string + spec *specs.Spec + expected *specs.Spec + expectErr bool + expectUpdate bool + }{ + { + desc: "volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "volume annotations for sandbox with legacy log path", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLegacyLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + "dev.gvisor.spec.mount." + testVolumeName + ".source": testVolumePath, + }, + }, + expectUpdate: true, + }, + { + desc: "tmpfs: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "tmpfs", + Source: testVolumePath, + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "bind: volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: testVolumePath, + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "container", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "bind", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expectUpdate: true, + }, + { + desc: "should not return error without pod log directory", + spec: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount." + testVolumeName + ".share": "pod", + "dev.gvisor.spec.mount." + testVolumeName + ".type": "tmpfs", + "dev.gvisor.spec.mount." + testVolumeName + ".options": "ro", + }, + }, + }, + { + desc: "should return error if volume path does not exist", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + "dev.gvisor.spec.mount.notexist.share": "pod", + "dev.gvisor.spec.mount.notexist.type": "tmpfs", + "dev.gvisor.spec.mount.notexist.options": "ro", + }, + }, + expectErr: true, + }, + { + desc: "no volume annotations for sandbox", + spec: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + expected: &specs.Spec{ + Annotations: map[string]string{ + sandboxLogDirAnnotation: testLogDirPath, + containerTypeAnnotation: containerTypeSandbox, + }, + }, + }, + { + desc: "no volume annotations for container", + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + expected: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/test", + Type: "bind", + Source: "/test", + Options: []string{"ro"}, + }, + { + Destination: "/random", + Type: "bind", + Source: "/random", + Options: []string{"ro"}, + }, + }, + Annotations: map[string]string{ + containerTypeAnnotation: containerTypeContainer, + }, + }, + }, + } { + t.Run(test.desc, func(t *testing.T) { + bundle, err := ioutil.TempDir(dir, "test-bundle") + if err != nil { + t.Fatalf("Create test bundle: %v", err) + } + err = UpdateVolumeAnnotations(bundle, test.spec) + if test.expectErr { + if err == nil { + t.Fatal("Expected error, but got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(test.expected, test.spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, test.spec) + } + if test.expectUpdate { + b, err := ioutil.ReadFile(filepath.Join(bundle, "config.json")) + if err != nil { + t.Fatalf("Read spec from bundle: %v", err) + } + var spec specs.Spec + if err := json.Unmarshal(b, &spec); err != nil { + t.Fatalf("Unmarshal spec: %v", err) + } + if !reflect.DeepEqual(test.expected, &spec) { + t.Fatalf("Expected %+v, got %+v", test.expected, &spec) + } + } + }) + } +} diff --git a/pkg/shim/v2/BUILD b/pkg/shim/v2/BUILD new file mode 100644 index 000000000..7e0a114a0 --- /dev/null +++ b/pkg/shim/v2/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "v2", + srcs = [ + "api.go", + "epoll.go", + "service.go", + "service_linux.go", + ], + visibility = ["//shim:__subpackages__"], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/proc", + "//pkg/shim/v1/utils", + "//pkg/shim/v2/options", + "//pkg/shim/v2/runtimeoptions", + "//runsc/specutils", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_cgroups//:go_default_library", + "@com_github_containerd_console//:go_default_library", + "@com_github_containerd_containerd//api/events:go_default_library", + "@com_github_containerd_containerd//api/types/task:go_default_library", + "@com_github_containerd_containerd//errdefs:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//log:go_default_library", + "@com_github_containerd_containerd//mount:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//pkg/process:go_default_library", + "@com_github_containerd_containerd//pkg/stdio:go_default_library", + "@com_github_containerd_containerd//runtime:go_default_library", + "@com_github_containerd_containerd//runtime/linux/runctypes:go_default_library", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + "@com_github_containerd_containerd//runtime/v2/task:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_fifo//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/pkg/shim/v2/api.go b/pkg/shim/v2/api.go new file mode 100644 index 000000000..dbe5c59f6 --- /dev/null +++ b/pkg/shim/v2/api.go @@ -0,0 +1,22 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "github.com/containerd/containerd/api/events" +) + +type TaskOOM = events.TaskOOM diff --git a/pkg/shim/v2/epoll.go b/pkg/shim/v2/epoll.go new file mode 100644 index 000000000..41232cca8 --- /dev/null +++ b/pkg/shim/v2/epoll.go @@ -0,0 +1,129 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "sync" + + "github.com/containerd/cgroups" + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/runtime" + "golang.org/x/sys/unix" +) + +func newOOMEpoller(publisher events.Publisher) (*epoller, error) { + fd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC) + if err != nil { + return nil, err + } + return &epoller{ + fd: fd, + publisher: publisher, + set: make(map[uintptr]*item), + }, nil +} + +type epoller struct { + mu sync.Mutex + + fd int + publisher events.Publisher + set map[uintptr]*item +} + +type item struct { + id string + cg cgroups.Cgroup +} + +func (e *epoller) Close() error { + return unix.Close(e.fd) +} + +func (e *epoller) run(ctx context.Context) { + var events [128]unix.EpollEvent + for { + select { + case <-ctx.Done(): + e.Close() + return + default: + n, err := unix.EpollWait(e.fd, events[:], -1) + if err != nil { + if err == unix.EINTR || err == unix.EAGAIN { + continue + } + // Should not happen. + panic(fmt.Errorf("cgroups: epoll wait: %w", err)) + } + for i := 0; i < n; i++ { + e.process(ctx, uintptr(events[i].Fd)) + } + } + } +} + +func (e *epoller) add(id string, cg cgroups.Cgroup) error { + e.mu.Lock() + defer e.mu.Unlock() + fd, err := cg.OOMEventFD() + if err != nil { + return err + } + e.set[fd] = &item{ + id: id, + cg: cg, + } + event := unix.EpollEvent{ + Fd: int32(fd), + Events: unix.EPOLLHUP | unix.EPOLLIN | unix.EPOLLERR, + } + return unix.EpollCtl(e.fd, unix.EPOLL_CTL_ADD, int(fd), &event) +} + +func (e *epoller) process(ctx context.Context, fd uintptr) { + flush(fd) + e.mu.Lock() + i, ok := e.set[fd] + if !ok { + e.mu.Unlock() + return + } + e.mu.Unlock() + if i.cg.State() == cgroups.Deleted { + e.mu.Lock() + delete(e.set, fd) + e.mu.Unlock() + unix.Close(int(fd)) + return + } + if err := e.publisher.Publish(ctx, runtime.TaskOOMEventTopic, &TaskOOM{ + ContainerID: i.id, + }); err != nil { + // Should not happen. + panic(fmt.Errorf("publish OOM event: %w", err)) + } +} + +func flush(fd uintptr) error { + var buf [8]byte + _, err := unix.Read(int(fd), buf[:]) + return err +} diff --git a/pkg/shim/v2/options/BUILD b/pkg/shim/v2/options/BUILD new file mode 100644 index 000000000..ca212e874 --- /dev/null +++ b/pkg/shim/v2/options/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "options", + srcs = [ + "options.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/shim/v2/options/options.go b/pkg/shim/v2/options/options.go new file mode 100644 index 000000000..de09f2f79 --- /dev/null +++ b/pkg/shim/v2/options/options.go @@ -0,0 +1,33 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +const OptionType = "io.containerd.runsc.v1.options" + +// Options is runtime options for io.containerd.runsc.v1. +type Options struct { + // ShimCgroup is the cgroup the shim should be in. + ShimCgroup string `toml:"shim_cgroup"` + // IoUid is the I/O's pipes uid. + IoUid uint32 `toml:"io_uid"` + // IoUid is the I/O's pipes gid. + IoGid uint32 `toml:"io_gid"` + // BinaryName is the binary name of the runsc binary. + BinaryName string `toml:"binary_name"` + // Root is the runsc root directory. + Root string `toml:"root"` + // RunscConfig is a key/value map of all runsc flags. + RunscConfig map[string]string `toml:"runsc_config"` +} diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD new file mode 100644 index 000000000..01716034c --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -0,0 +1,20 @@ +load("//tools:defs.bzl", "go_library", "proto_library") + +package(licenses = ["notice"]) + +proto_library( + name = "api", + srcs = [ + "runtimeoptions.proto", + ], +) + +go_library( + name = "runtimeoptions", + srcs = ["runtimeoptions.go"], + visibility = ["//pkg/shim/v2:__pkg__"], + deps = [ + "//pkg/shim/v2/runtimeoptions:api_go_proto", + "@com_github_gogo_protobuf//proto:go_default_library", + ], +) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go new file mode 100644 index 000000000..1c1a0c5d1 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -0,0 +1,27 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + proto "github.com/gogo/protobuf/proto" + pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto" +) + +type Options = pb.Options + +func init() { + proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto new file mode 100644 index 000000000..edb19020a --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package runtimeoptions; + +// This is a version of the runtimeoptions CRI API that is vendored. +// +// Imported the full CRI package is a nightmare. +message Options { + string type_url = 1; + string config_path = 2; +} diff --git a/pkg/shim/v2/service.go b/pkg/shim/v2/service.go new file mode 100644 index 000000000..1534152fc --- /dev/null +++ b/pkg/shim/v2/service.go @@ -0,0 +1,824 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package v2 + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "sync" + "syscall" + "time" + + "github.com/BurntSushi/toml" + "github.com/containerd/cgroups" + "github.com/containerd/console" + "github.com/containerd/containerd/api/events" + "github.com/containerd/containerd/api/types/task" + "github.com/containerd/containerd/errdefs" + "github.com/containerd/containerd/log" + "github.com/containerd/containerd/mount" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/pkg/process" + "github.com/containerd/containerd/pkg/stdio" + "github.com/containerd/containerd/runtime" + "github.com/containerd/containerd/runtime/linux/runctypes" + "github.com/containerd/containerd/runtime/v2/shim" + taskAPI "github.com/containerd/containerd/runtime/v2/task" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/proc" + "gvisor.dev/gvisor/pkg/shim/v1/utils" + "gvisor.dev/gvisor/pkg/shim/v2/options" + "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions" + "gvisor.dev/gvisor/runsc/specutils" +) + +var ( + empty = &types.Empty{} + bufPool = sync.Pool{ + New: func() interface{} { + buffer := make([]byte, 32<<10) + return &buffer + }, + } +) + +var _ = (taskAPI.TaskService)(&service{}) + +// configFile is the default config file name. For containerd 1.2, +// we assume that a config.toml should exist in the runtime root. +const configFile = "config.toml" + +// New returns a new shim service that can be used via GRPC. +func New(ctx context.Context, id string, publisher shim.Publisher, cancel func()) (shim.Shim, error) { + ep, err := newOOMEpoller(publisher) + if err != nil { + return nil, err + } + go ep.run(ctx) + s := &service{ + id: id, + context: ctx, + processes: make(map[string]process.Process), + events: make(chan interface{}, 128), + ec: proc.ExitCh, + oomPoller: ep, + cancel: cancel, + } + go s.processExits() + runsc.Monitor = reaper.Default + if err := s.initPlatform(); err != nil { + cancel() + return nil, fmt.Errorf("failed to initialized platform behavior: %w", err) + } + go s.forward(publisher) + return s, nil +} + +// service is the shim implementation of a remote shim over GRPC. +type service struct { + mu sync.Mutex + + context context.Context + task process.Process + processes map[string]process.Process + events chan interface{} + platform stdio.Platform + opts options.Options + ec chan proc.Exit + oomPoller *epoller + + id string + bundle string + cancel func() +} + +func newCommand(ctx context.Context, containerdBinary, containerdAddress string) (*exec.Cmd, error) { + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + self, err := os.Executable() + if err != nil { + return nil, err + } + cwd, err := os.Getwd() + if err != nil { + return nil, err + } + args := []string{ + "-namespace", ns, + "-address", containerdAddress, + "-publish-binary", containerdBinary, + } + cmd := exec.Command(self, args...) + cmd.Dir = cwd + cmd.Env = append(os.Environ(), "GOMAXPROCS=2") + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + return cmd, nil +} + +func (s *service) StartShim(ctx context.Context, id, containerdBinary, containerdAddress, containerdTTRPCAddress string) (string, error) { + cmd, err := newCommand(ctx, containerdBinary, containerdAddress) + if err != nil { + return "", err + } + address, err := shim.SocketAddress(ctx, id) + if err != nil { + return "", err + } + socket, err := shim.NewSocket(address) + if err != nil { + return "", err + } + defer socket.Close() + f, err := socket.File() + if err != nil { + return "", err + } + defer f.Close() + + cmd.ExtraFiles = append(cmd.ExtraFiles, f) + + if err := cmd.Start(); err != nil { + return "", err + } + defer func() { + if err != nil { + cmd.Process.Kill() + } + }() + // make sure to wait after start + go cmd.Wait() + if err := shim.WritePidFile("shim.pid", cmd.Process.Pid); err != nil { + return "", err + } + if err := shim.WriteAddress("address", address); err != nil { + return "", err + } + if err := shim.SetScore(cmd.Process.Pid); err != nil { + return "", fmt.Errorf("failed to set OOM Score on shim: %w", err) + } + return address, nil +} + +func (s *service) Cleanup(ctx context.Context) (*taskAPI.DeleteResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + r := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + if err := r.Delete(ctx, s.id, &runsc.DeleteOpts{ + Force: true, + }); err != nil { + log.L.Printf("failed to remove runc container: %v", err) + } + if err := mount.UnmountAll(filepath.Join(path, "rootfs"), 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + return &taskAPI.DeleteResponse{ + ExitedAt: time.Now(), + ExitStatus: 128 + uint32(unix.SIGKILL), + }, nil +} + +func (s *service) readRuntime(path string) (string, error) { + data, err := ioutil.ReadFile(filepath.Join(path, "runtime")) + if err != nil { + return "", err + } + return string(data), nil +} + +func (s *service) writeRuntime(path, runtime string) error { + return ioutil.WriteFile(filepath.Join(path, "runtime"), []byte(runtime), 0600) +} + +// Create creates a new initial process and container with the underlying OCI +// runtime. +func (s *service) Create(ctx context.Context, r *taskAPI.CreateTaskRequest) (_ *taskAPI.CreateTaskResponse, err error) { + s.mu.Lock() + defer s.mu.Unlock() + + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, fmt.Errorf("create namespace: %w", err) + } + + // Read from root for now. + var opts options.Options + if r.Options != nil { + v, err := typeurl.UnmarshalAny(r.Options) + if err != nil { + return nil, err + } + var path string + switch o := v.(type) { + case *runctypes.CreateOptions: // containerd 1.2.x + opts.IoUid = o.IoUid + opts.IoGid = o.IoGid + opts.ShimCgroup = o.ShimCgroup + case *runctypes.RuncOptions: // containerd 1.2.x + root := proc.RunscRoot + if o.RuntimeRoot != "" { + root = o.RuntimeRoot + } + + opts.BinaryName = o.Runtime + + path = filepath.Join(root, configFile) + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("stat config file %q: %w", path, err) + } + // A config file in runtime root is not required. + path = "" + } + case *runtimeoptions.Options: // containerd 1.3.x+ + if o.ConfigPath == "" { + break + } + if o.TypeUrl != options.OptionType { + return nil, fmt.Errorf("unsupported option type %q", o.TypeUrl) + } + path = o.ConfigPath + default: + return nil, fmt.Errorf("unsupported option type %q", r.Options.TypeUrl) + } + if path != "" { + if _, err = toml.DecodeFile(path, &opts); err != nil { + return nil, fmt.Errorf("decode config file %q: %w", path, err) + } + } + } + + var mounts []proc.Mount + for _, m := range r.Rootfs { + mounts = append(mounts, proc.Mount{ + Type: m.Type, + Source: m.Source, + Target: m.Target, + Options: m.Options, + }) + } + + rootfs := filepath.Join(r.Bundle, "rootfs") + if err := os.Mkdir(rootfs, 0711); err != nil && !os.IsExist(err) { + return nil, err + } + + config := &proc.CreateConfig{ + ID: r.ID, + Bundle: r.Bundle, + Runtime: opts.BinaryName, + Rootfs: mounts, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Options: r.Options, + } + if err := s.writeRuntime(r.Bundle, opts.BinaryName); err != nil { + return nil, err + } + defer func() { + if err != nil { + if err := mount.UnmountAll(rootfs, 0); err != nil { + log.L.Printf("failed to cleanup rootfs mount: %v", err) + } + } + }() + for _, rm := range mounts { + m := &mount.Mount{ + Type: rm.Type, + Source: rm.Source, + Options: rm.Options, + } + if err := m.Mount(rootfs); err != nil { + return nil, fmt.Errorf("failed to mount rootfs component %v: %w", m, err) + } + } + process, err := newInit( + ctx, + r.Bundle, + filepath.Join(r.Bundle, "work"), + ns, + s.platform, + config, + &opts, + rootfs, + ) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + if err := process.Create(ctx, config); err != nil { + return nil, errdefs.ToGRPC(err) + } + // Save the main task id and bundle to the shim for additional + // requests. + s.id = r.ID + s.bundle = r.Bundle + + // Set up OOM notification on the sandbox's cgroup. This is done on + // sandbox create since the sandbox process will be created here. + pid := process.Pid() + if pid > 0 { + cg, err := cgroups.Load(cgroups.V1, cgroups.PidPath(pid)) + if err != nil { + return nil, fmt.Errorf("loading cgroup for %d: %w", pid, err) + } + if err := s.oomPoller.add(s.id, cg); err != nil { + return nil, fmt.Errorf("add cg to OOM monitor: %w", err) + } + } + s.task = process + s.opts = opts + return &taskAPI.CreateTaskResponse{ + Pid: uint32(process.Pid()), + }, nil + +} + +// Start starts a process. +func (s *service) Start(ctx context.Context, r *taskAPI.StartRequest) (*taskAPI.StartResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if err := p.Start(ctx); err != nil { + return nil, err + } + // TODO: Set the cgroup and oom notifications on restore. + // https://github.com/google/gvisor-containerd-shim/issues/58 + return &taskAPI.StartResponse{ + Pid: uint32(p.Pid()), + }, nil +} + +// Delete deletes the initial process and container. +func (s *service) Delete(ctx context.Context, r *taskAPI.DeleteRequest) (*taskAPI.DeleteResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Delete(ctx); err != nil { + return nil, err + } + isTask := r.ExecID == "" + if !isTask { + s.mu.Lock() + delete(s.processes, r.ExecID) + s.mu.Unlock() + } + if isTask && s.platform != nil { + s.platform.Close() + } + return &taskAPI.DeleteResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + Pid: uint32(p.Pid()), + }, nil +} + +// Exec spawns an additional process inside the container. +func (s *service) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*types.Empty, error) { + s.mu.Lock() + p := s.processes[r.ExecID] + s.mu.Unlock() + if p != nil { + return nil, errdefs.ToGRPCf(errdefs.ErrAlreadyExists, "id %s", r.ExecID) + } + p = s.task + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + process, err := p.(*proc.Init).Exec(ctx, s.bundle, &proc.ExecConfig{ + ID: r.ExecID, + Terminal: r.Terminal, + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Spec: r.Spec, + }) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + s.mu.Lock() + s.processes[r.ExecID] = process + s.mu.Unlock() + return empty, nil +} + +// ResizePty resizes the terminal of a process. +func (s *service) ResizePty(ctx context.Context, r *taskAPI.ResizePtyRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + ws := console.WinSize{ + Width: uint16(r.Width), + Height: uint16(r.Height), + } + if err := p.Resize(ws); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// State returns runtime state information for a process. +func (s *service) State(ctx context.Context, r *taskAPI.StateRequest) (*taskAPI.StateResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + st, err := p.Status(ctx) + if err != nil { + return nil, err + } + status := task.StatusUnknown + switch st { + case "created": + status = task.StatusCreated + case "running": + status = task.StatusRunning + case "stopped": + status = task.StatusStopped + } + sio := p.Stdio() + return &taskAPI.StateResponse{ + ID: p.ID(), + Bundle: s.bundle, + Pid: uint32(p.Pid()), + Status: status, + Stdin: sio.Stdin, + Stdout: sio.Stdout, + Stderr: sio.Stderr, + Terminal: sio.Terminal, + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +// Pause the container. +func (s *service) Pause(ctx context.Context, r *taskAPI.PauseRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Resume the container. +func (s *service) Resume(ctx context.Context, r *taskAPI.ResumeRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Kill a process with the provided signal. +func (s *service) Kill(ctx context.Context, r *taskAPI.KillRequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + if err := p.Kill(ctx, r.Signal, r.All); err != nil { + return nil, errdefs.ToGRPC(err) + } + return empty, nil +} + +// Pids returns all pids inside the container. +func (s *service) Pids(ctx context.Context, r *taskAPI.PidsRequest) (*taskAPI.PidsResponse, error) { + pids, err := s.getContainerPids(ctx, r.ID) + if err != nil { + return nil, errdefs.ToGRPC(err) + } + var processes []*task.ProcessInfo + for _, pid := range pids { + pInfo := task.ProcessInfo{ + Pid: pid, + } + for _, p := range s.processes { + if p.Pid() == int(pid) { + d := &runctypes.ProcessDetails{ + ExecID: p.ID(), + } + a, err := typeurl.MarshalAny(d) + if err != nil { + return nil, fmt.Errorf("failed to marshal process %d info: %w", pid, err) + } + pInfo.Info = a + break + } + } + processes = append(processes, &pInfo) + } + return &taskAPI.PidsResponse{ + Processes: processes, + }, nil +} + +// CloseIO closes the I/O context of a process. +func (s *service) CloseIO(ctx context.Context, r *taskAPI.CloseIORequest) (*types.Empty, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if stdin := p.Stdin(); stdin != nil { + if err := stdin.Close(); err != nil { + return nil, fmt.Errorf("close stdin: %w", err) + } + } + return empty, nil +} + +// Checkpoint checkpoints the container. +func (s *service) Checkpoint(ctx context.Context, r *taskAPI.CheckpointTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Connect returns shim information such as the shim's pid. +func (s *service) Connect(ctx context.Context, r *taskAPI.ConnectRequest) (*taskAPI.ConnectResponse, error) { + var pid int + if s.task != nil { + pid = s.task.Pid() + } + return &taskAPI.ConnectResponse{ + ShimPid: uint32(os.Getpid()), + TaskPid: uint32(pid), + }, nil +} + +func (s *service) Shutdown(ctx context.Context, r *taskAPI.ShutdownRequest) (*types.Empty, error) { + s.cancel() + os.Exit(0) + return empty, nil +} + +func (s *service) Stats(ctx context.Context, r *taskAPI.StatsRequest) (*taskAPI.StatsResponse, error) { + path, err := os.Getwd() + if err != nil { + return nil, err + } + ns, err := namespaces.NamespaceRequired(ctx) + if err != nil { + return nil, err + } + runtime, err := s.readRuntime(path) + if err != nil { + return nil, err + } + rs := proc.NewRunsc(s.opts.Root, path, ns, runtime, nil) + stats, err := rs.Stats(ctx, s.id) + if err != nil { + return nil, err + } + + // gvisor currently (as of 2020-03-03) only returns the total memory + // usage and current PID value[0]. However, we copy the common fields here + // so that future updates will propagate correct information. We're + // using the cgroups.Metrics structure so we're returning the same type + // as runc. + // + // [0]: https://github.com/google/gvisor/blob/277a0d5a1fbe8272d4729c01ee4c6e374d047ebc/runsc/boot/events.go#L61-L81 + data, err := typeurl.MarshalAny(&cgroups.Metrics{ + CPU: &cgroups.CPUStat{ + Usage: &cgroups.CPUUsage{ + Total: stats.Cpu.Usage.Total, + Kernel: stats.Cpu.Usage.Kernel, + User: stats.Cpu.Usage.User, + PerCPU: stats.Cpu.Usage.Percpu, + }, + Throttling: &cgroups.Throttle{ + Periods: stats.Cpu.Throttling.Periods, + ThrottledPeriods: stats.Cpu.Throttling.ThrottledPeriods, + ThrottledTime: stats.Cpu.Throttling.ThrottledTime, + }, + }, + Memory: &cgroups.MemoryStat{ + Cache: stats.Memory.Cache, + Usage: &cgroups.MemoryEntry{ + Limit: stats.Memory.Usage.Limit, + Usage: stats.Memory.Usage.Usage, + Max: stats.Memory.Usage.Max, + Failcnt: stats.Memory.Usage.Failcnt, + }, + Swap: &cgroups.MemoryEntry{ + Limit: stats.Memory.Swap.Limit, + Usage: stats.Memory.Swap.Usage, + Max: stats.Memory.Swap.Max, + Failcnt: stats.Memory.Swap.Failcnt, + }, + Kernel: &cgroups.MemoryEntry{ + Limit: stats.Memory.Kernel.Limit, + Usage: stats.Memory.Kernel.Usage, + Max: stats.Memory.Kernel.Max, + Failcnt: stats.Memory.Kernel.Failcnt, + }, + KernelTCP: &cgroups.MemoryEntry{ + Limit: stats.Memory.KernelTCP.Limit, + Usage: stats.Memory.KernelTCP.Usage, + Max: stats.Memory.KernelTCP.Max, + Failcnt: stats.Memory.KernelTCP.Failcnt, + }, + }, + Pids: &cgroups.PidsStat{ + Current: stats.Pids.Current, + Limit: stats.Pids.Limit, + }, + }) + if err != nil { + return nil, err + } + return &taskAPI.StatsResponse{ + Stats: data, + }, nil +} + +// Update updates a running container. +func (s *service) Update(ctx context.Context, r *taskAPI.UpdateTaskRequest) (*types.Empty, error) { + return empty, errdefs.ToGRPC(errdefs.ErrNotImplemented) +} + +// Wait waits for a process to exit. +func (s *service) Wait(ctx context.Context, r *taskAPI.WaitRequest) (*taskAPI.WaitResponse, error) { + p, err := s.getProcess(r.ExecID) + if err != nil { + return nil, err + } + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrFailedPrecondition, "container must be created") + } + p.Wait() + + return &taskAPI.WaitResponse{ + ExitStatus: uint32(p.ExitStatus()), + ExitedAt: p.ExitedAt(), + }, nil +} + +func (s *service) processExits() { + for e := range s.ec { + s.checkProcesses(e) + } +} + +func (s *service) checkProcesses(e proc.Exit) { + // TODO(random-liu): Add `shouldKillAll` logic if container pid + // namespace is supported. + for _, p := range s.allProcesses() { + if p.ID() == e.ID { + if ip, ok := p.(*proc.Init); ok { + // Ensure all children are killed. + if err := ip.KillAll(s.context); err != nil { + log.G(s.context).WithError(err).WithField("id", ip.ID()). + Error("failed to kill init's children") + } + } + p.SetExited(e.Status) + s.events <- &events.TaskExit{ + ContainerID: s.id, + ID: p.ID(), + Pid: uint32(p.Pid()), + ExitStatus: uint32(e.Status), + ExitedAt: p.ExitedAt(), + } + return + } + } +} + +func (s *service) allProcesses() (o []process.Process) { + s.mu.Lock() + defer s.mu.Unlock() + for _, p := range s.processes { + o = append(o, p) + } + if s.task != nil { + o = append(o, s.task) + } + return o +} + +func (s *service) getContainerPids(ctx context.Context, id string) ([]uint32, error) { + s.mu.Lock() + p := s.task + s.mu.Unlock() + if p == nil { + return nil, fmt.Errorf("container must be created: %w", errdefs.ErrFailedPrecondition) + } + ps, err := p.(*proc.Init).Runtime().Ps(ctx, id) + if err != nil { + return nil, err + } + pids := make([]uint32, 0, len(ps)) + for _, pid := range ps { + pids = append(pids, uint32(pid)) + } + return pids, nil +} + +func (s *service) forward(publisher shim.Publisher) { + for e := range s.events { + ctx, cancel := context.WithTimeout(s.context, 5*time.Second) + err := publisher.Publish(ctx, getTopic(e), e) + cancel() + if err != nil { + // Should not happen. + panic(fmt.Errorf("post event: %w", err)) + } + } +} + +func (s *service) getProcess(execID string) (process.Process, error) { + s.mu.Lock() + defer s.mu.Unlock() + if execID == "" { + return s.task, nil + } + p := s.processes[execID] + if p == nil { + return nil, errdefs.ToGRPCf(errdefs.ErrNotFound, "process does not exist %s", execID) + } + return p, nil +} + +func getTopic(e interface{}) string { + switch e.(type) { + case *events.TaskCreate: + return runtime.TaskCreateEventTopic + case *events.TaskStart: + return runtime.TaskStartEventTopic + case *events.TaskOOM: + return runtime.TaskOOMEventTopic + case *events.TaskExit: + return runtime.TaskExitEventTopic + case *events.TaskDelete: + return runtime.TaskDeleteEventTopic + case *events.TaskExecAdded: + return runtime.TaskExecAddedEventTopic + case *events.TaskExecStarted: + return runtime.TaskExecStartedEventTopic + default: + log.L.Printf("no topic for type %#v", e) + } + return runtime.TaskUnknownTopic +} + +func newInit(ctx context.Context, path, workDir, namespace string, platform stdio.Platform, r *proc.CreateConfig, options *options.Options, rootfs string) (*proc.Init, error) { + spec, err := utils.ReadSpec(r.Bundle) + if err != nil { + return nil, fmt.Errorf("read oci spec: %w", err) + } + if err := utils.UpdateVolumeAnnotations(r.Bundle, spec); err != nil { + return nil, fmt.Errorf("update volume annotations: %w", err) + } + runsc.FormatLogPath(r.ID, options.RunscConfig) + runtime := proc.NewRunsc(options.Root, path, namespace, options.BinaryName, options.RunscConfig) + p := proc.New(r.ID, runtime, stdio.Stdio{ + Stdin: r.Stdin, + Stdout: r.Stdout, + Stderr: r.Stderr, + Terminal: r.Terminal, + }) + p.Bundle = r.Bundle + p.Platform = platform + p.Rootfs = rootfs + p.WorkDir = workDir + p.IoUID = int(options.IoUid) + p.IoGID = int(options.IoGid) + p.Sandbox = specutils.SpecContainerType(spec) == specutils.ContainerTypeSandbox + p.UserLog = utils.UserLogPath(spec) + p.Monitor = reaper.Default + return p, nil +} diff --git a/pkg/shim/v2/service_linux.go b/pkg/shim/v2/service_linux.go new file mode 100644 index 000000000..1800ab90b --- /dev/null +++ b/pkg/shim/v2/service_linux.go @@ -0,0 +1,108 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package v2 + +import ( + "context" + "fmt" + "io" + "sync" + "syscall" + + "github.com/containerd/console" + "github.com/containerd/fifo" +) + +type linuxPlatform struct { + epoller *console.Epoller +} + +func (p *linuxPlatform) CopyConsole(ctx context.Context, console console.Console, stdin, stdout, stderr string, wg *sync.WaitGroup) (console.Console, error) { + if p.epoller == nil { + return nil, fmt.Errorf("uninitialized epoller") + } + + epollConsole, err := p.epoller.Add(console) + if err != nil { + return nil, err + } + + if stdin != "" { + in, err := fifo.OpenFifo(context.Background(), stdin, syscall.O_RDONLY|syscall.O_NONBLOCK, 0) + if err != nil { + return nil, err + } + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(epollConsole, in, *p) + }() + } + + outw, err := fifo.OpenFifo(ctx, stdout, syscall.O_WRONLY, 0) + if err != nil { + return nil, err + } + outr, err := fifo.OpenFifo(ctx, stdout, syscall.O_RDONLY, 0) + if err != nil { + return nil, err + } + wg.Add(1) + go func() { + p := bufPool.Get().(*[]byte) + defer bufPool.Put(p) + io.CopyBuffer(outw, epollConsole, *p) + epollConsole.Close() + outr.Close() + outw.Close() + wg.Done() + }() + return epollConsole, nil +} + +func (p *linuxPlatform) ShutdownConsole(ctx context.Context, cons console.Console) error { + if p.epoller == nil { + return fmt.Errorf("uninitialized epoller") + } + epollConsole, ok := cons.(*console.EpollConsole) + if !ok { + return fmt.Errorf("expected EpollConsole, got %#v", cons) + } + return epollConsole.Shutdown(p.epoller.CloseConsole) +} + +func (p *linuxPlatform) Close() error { + return p.epoller.Close() +} + +// initialize a single epoll fd to manage our consoles. `initPlatform` should +// only be called once. +func (s *service) initPlatform() error { + if s.platform != nil { + return nil + } + epoller, err := console.NewEpoller() + if err != nil { + return fmt.Errorf("failed to initialize epoller: %w", err) + } + s.platform = &linuxPlatform{ + epoller: epoller, + } + go epoller.Wait() + return nil +} diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD index e131455f7..ae0fe1522 100644 --- a/pkg/sleep/BUILD +++ b/pkg/sleep/BUILD @@ -12,6 +12,7 @@ go_library( "sleep_unsafe.go", ], visibility = ["//:sandbox"], + deps = ["//pkg/sync"], ) go_test( diff --git a/pkg/sleep/sleep_test.go b/pkg/sleep/sleep_test.go index af47e2ba1..1dd11707d 100644 --- a/pkg/sleep/sleep_test.go +++ b/pkg/sleep/sleep_test.go @@ -379,10 +379,7 @@ func TestRace(t *testing.T) { // TestRaceInOrder tests that multiple wakers can continuously send wake requests to // the sleeper and that the wakers are retrieved in the order asserted. func TestRaceInOrder(t *testing.T) { - const wakers = 100 - const wakeRequests = 10000 - - w := make([]Waker, wakers) + w := make([]Waker, 10000) s := Sleeper{} // Associate each waker and start goroutines that will assert them. @@ -390,19 +387,16 @@ func TestRaceInOrder(t *testing.T) { s.AddWaker(&w[i], i) } go func() { - n := 0 - for n < wakeRequests { - wk := w[n%len(w)] - wk.Assert() - n++ + for i := range w { + w[i].Assert() } }() // Wait for all wake up notifications from all wakers. - for i := 0; i < wakeRequests; i++ { - v, _ := s.Fetch(true) - if got, want := v, i%wakers; got != want { - t.Fatalf("got %d want %d", got, want) + for want := range w { + got, _ := s.Fetch(true) + if got != want { + t.Fatalf("got %d want %d", got, want) } } } diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index f68c12620..118805492 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -75,6 +75,8 @@ package sleep import ( "sync/atomic" "unsafe" + + "gvisor.dev/gvisor/pkg/sync" ) const ( @@ -323,7 +325,12 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) { // // This struct is thread-safe, that is, its methods can be called concurrently // by multiple goroutines. +// +// Note, it is not safe to copy a Waker as its fields are modified by value +// (the pointer fields are individually modified with atomic operations). type Waker struct { + _ sync.NoCopy + // s is the sleeper that this waker can wake up. Only one sleeper at a // time is allowed. This field can have three classes of values: // nil -- the waker is not asserted: it either is not associated with diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 2b1350135..089b3bbef 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -1,9 +1,47 @@ -load("//tools:defs.bzl", "go_library", "go_test", "proto_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( + name = "pending_list", + out = "pending_list.go", + package = "state", + prefix = "pending", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "pendingMapper", + "Linker": "*pendingEntry", + }, +) + +go_template_instance( + name = "deferred_list", + out = "deferred_list.go", + package = "state", + prefix = "deferred", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectEncodeState", + "ElementMapper": "deferredMapper", + "Linker": "*deferredEntry", + }, +) + +go_template_instance( + name = "complete_list", + out = "complete_list.go", + package = "state", + prefix = "complete", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*objectDecodeState", + "Linker": "*objectDecodeState", + }, +) + +go_template_instance( name = "addr_range", out = "addr_range.go", package = "state", @@ -29,7 +67,7 @@ go_template_instance( types = { "Key": "uintptr", "Range": "addrRange", - "Value": "reflect.Value", + "Value": "*objectEncodeState", "Functions": "addrSetFunctions", }, ) @@ -39,32 +77,24 @@ go_library( srcs = [ "addr_range.go", "addr_set.go", + "complete_list.go", "decode.go", + "decode_unsafe.go", + "deferred_list.go", "encode.go", "encode_unsafe.go", - "map.go", - "printer.go", + "pending_list.go", "state.go", + "state_norace.go", + "state_race.go", "stats.go", + "types.go", ], marshal = False, stateify = False, visibility = ["//:sandbox"], deps = [ - ":object_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", + "//pkg/log", + "//pkg/state/wire", ], ) - -proto_library( - name = "object", - srcs = ["object.proto"], - visibility = ["//:sandbox"], -) - -go_test( - name = "state_test", - timeout = "long", - srcs = ["state_test.go"], - library = ":state", -) diff --git a/pkg/state/README.md b/pkg/state/README.md new file mode 100644 index 000000000..1aa401193 --- /dev/null +++ b/pkg/state/README.md @@ -0,0 +1,158 @@ +# State Encoding and Decoding + +The state package implements the encoding and decoding of data structures for +`go_stateify`. This package is designed for use cases other than the standard +encoding packages, e.g. `gob` and `json`. Principally: + +* This package operates on complex object graphs and accurately serializes and + restores all relationships. That is, you can have things like: intrusive + pointers, cycles, and pointer chains of arbitrary depths. These are not + handled appropriately by existing encoders. This is not an implementation + flaw: the formats themselves are not capable of representing these graphs, + as they can only generate directed trees. + +* This package allows installing order-dependent load callbacks and then + resolves that graph at load time, with cycle detection. Similarly, there is + no analogous feature possible in the standard encoders. + +* This package handles the resolution of interfaces, based on a registered + type name. For interface objects type information is saved in the serialized + format. This is generally true for `gob` as well, but it works differently. + +Here's an overview of how encoding and decoding works. + +## Encoding + +Encoding produces a `statefile`, which contains a list of chunks of the form +`(header, payload)`. The payload can either be some raw data, or a series of +encoded wire objects representing some object graph. All encoded objects are +defined in the `wire` subpackage. + +Encoding of an object graph begins with `encodeState.Save`. + +### 1. Memory Map & Encoding + +To discover relationships between potentially interdependent data structures +(for example, a struct may contain pointers to members of other data +structures), the encoder first walks the object graph and constructs a memory +map of the objects in the input graph. As this walk progresses, objects are +queued in the `pending` list and items are placed on the `deferred` list as they +are discovered. No single object will be encoded multiple times, but the +discovered relationships between objects may change as more parts of the overall +object graph are discovered. + +The encoder starts at the root object and recursively visits all reachable +objects, recording the address ranges containing the underlying data for each +object. This is stored as a segment set (`addrSet`), mapping address ranges to +the of the object occupying the range; see `encodeState.values`. Note that there +is special handling for zero-sized types and map objects during this process. + +Additionally, the encoder assigns each object a unique identifier which is used +to indicate relationships between objects in the statefile; see `objectID` in +`encode.go`. + +### 2. Type Serialization + +The enoder will subsequently serialize all information about discovered types, +including field names. These are used during decoding to reconcile these types +with other internally registered types. + +### 3. Object Serialization + +With a full address map, and all objects correctly encoded, all object encodings +are serialized. The assigned `objectID`s aren't explicitly encoded in the +statefile. The order of object messages in the stream determine their IDs. + +### Example + +Given the following data structure definitions: + +```go +type system struct { + o *outer + i *inner +} + +type outer struct { + a int64 + cn *container +} + +type container struct { + n uint64 + elem *inner +} + +type inner struct { + c container + x, y uint64 +} +``` + +Initialized like this: + +```go +o := outer{ + a: 10, + cn: nil, +} +i := inner{ + x: 20, + y: 30, + c: container{}, +} +s := system{ + o: &o, + i: &i, +} + +o.cn = &i.c +o.cn.elem = &i + +``` + +Encoding will produce an object stream like this: + +``` +g0r1 = struct{ + i: g0r3, + o: g0r2, +} +g0r2 = struct{ + a: 10, + cn: g0r3.c, +} +g0r3 = struct{ + c: struct{ + elem: g0r3, + n: 0u, + }, + x: 20u, + y: 30u, +} +``` + +Note how `g0r3.c` is correctly encoded as the underlying `container` object for +`inner.c`, and how the pointer from `outer.cn` points to it, despite `system.i` +being discovered after the pointer to it in `system.o.cn`. Also note that +decoding isn't strictly reliant on the order of encoded object stream, as long +as the relationship between objects are correctly encoded. + +## Decoding + +Decoding reads the statefile and reconstructs the object graph. Decoding begins +in `decodeState.Load`. Decoding is performed in a single pass over the object +stream in the statefile, and a subsequent pass over all deserialized objects is +done to fire off all loading callbacks in the correctly defined order. Note that +introducing cycles is possible here, but these are detected and an error will be +returned. + +Decoding is relatively straight forward. For most primitive values, the decoder +constructs an appropriate object and fills it with the values encoded in the +statefile. Pointers need special handling, as they must point to a value +allocated elsewhere. When values are constructed, the decoder indexes them by +their `objectID`s in `decodeState.objectsByID`. The target of pointers are +resolved by searching for the target in this index by their `objectID`; see +`decodeState.register`. For pointers to values inside another value (fields in a +pointer, elements of an array), the decoder uses the accessor path to walk to +the appropriate location; see `walkChild`. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 590c241a3..c9971cdf6 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -17,28 +17,49 @@ package state import ( "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" + "math" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// objectState represents an object that may be in the process of being +// internalCallback is a interface called on object completion. +// +// There are two implementations: objectDecodeState & userCallback. +type internalCallback interface { + // source returns the dependent object. May be nil. + source() *objectDecodeState + + // callbackRun executes the callback. + callbackRun() +} + +// userCallback is an implementation of internalCallback. +type userCallback func() + +// source implements internalCallback.source. +func (userCallback) source() *objectDecodeState { + return nil +} + +// callbackRun implements internalCallback.callbackRun. +func (uc userCallback) callbackRun() { + uc() +} + +// objectDecodeState represents an object that may be in the process of being // decoded. Specifically, it represents either a decoded object, or an an // interest in a future object that will be decoded. When that interest is // registered (via register), the storage for the object will be created, but // it will not be decoded until the object is encountered in the stream. -type objectState struct { +type objectDecodeState struct { // id is the id for this object. - // - // If this field is zero, then this is an anonymous (unregistered, - // non-reference primitive) object. This is immutable. - id uint64 + id objectID + + // typ is the id for this typeID. This may be zero if this is not a + // type-registered structure. + typ typeID // obj is the object. This may or may not be valid yet, depending on // whether complete returns true. However, regardless of whether the @@ -57,69 +78,52 @@ type objectState struct { // blockedBy is the number of dependencies this object has. blockedBy int - // blocking is a list of the objects blocked by this one. - blocking []*objectState + // callbacksInline is inline storage for callbacks. + callbacksInline [2]internalCallback // callbacks is a set of callbacks to execute on load. - callbacks []func() - - // path is the decoding path to the object. - path recoverable -} - -// complete indicates the object is complete. -func (os *objectState) complete() bool { - return os.blockedBy == 0 && len(os.callbacks) == 0 -} - -// checkComplete checks for completion. If the object is complete, pending -// callbacks will be executed and checkComplete will be called on downstream -// objects (those depending on this one). -func (os *objectState) checkComplete(stats *Stats) { - if os.blockedBy > 0 { - return - } - stats.Start(os.obj) + callbacks []internalCallback - // Fire all callbacks. - for _, fn := range os.callbacks { - fn() - } - os.callbacks = nil - - // Clear all blocked objects. - for _, other := range os.blocking { - other.blockedBy-- - other.checkComplete(stats) - } - os.blocking = nil - stats.Done() + completeEntry } -// waitFor queues a dependency on the given object. -func (os *objectState) waitFor(other *objectState, callback func()) { - os.blockedBy++ - other.blocking = append(other.blocking, os) - if callback != nil { - other.callbacks = append(other.callbacks, callback) +// addCallback adds a callback to the objectDecodeState. +func (ods *objectDecodeState) addCallback(ic internalCallback) { + if ods.callbacks == nil { + ods.callbacks = ods.callbacksInline[:0] } + ods.callbacks = append(ods.callbacks, ic) } // findCycleFor returns when the given object is found in the blocking set. -func (os *objectState) findCycleFor(target *objectState) []*objectState { - for _, other := range os.blocking { - if other == target { - return []*objectState{target} +func (ods *objectDecodeState) findCycleFor(target *objectDecodeState) []*objectDecodeState { + for _, ic := range ods.callbacks { + other := ic.source() + if other != nil && other == target { + return []*objectDecodeState{target} } else if childList := other.findCycleFor(target); childList != nil { return append(childList, other) } } - return nil + + // This should not occur. + Failf("no deadlock found?") + panic("unreachable") } // findCycle finds a dependency cycle. -func (os *objectState) findCycle() []*objectState { - return append(os.findCycleFor(os), os) +func (ods *objectDecodeState) findCycle() []*objectDecodeState { + return append(ods.findCycleFor(ods), ods) +} + +// source implements internalCallback.source. +func (ods *objectDecodeState) source() *objectDecodeState { + return ods +} + +// callbackRun implements internalCallback.callbackRun. +func (ods *objectDecodeState) callbackRun() { + ods.blockedBy-- } // decodeState is a graph of objects in the process of being decoded. @@ -137,30 +141,66 @@ type decodeState struct { // ctx is the decode context. ctx context.Context + // r is the input stream. + r wire.Reader + + // types is the type database. + types typeDecodeDatabase + // objectByID is the set of objects in progress. - objectsByID map[uint64]*objectState + objectsByID []*objectDecodeState // deferred are objects that have been read, by no interest has been // registered yet. These will be decoded once interest in registered. - deferred map[uint64]*pb.Object + deferred map[objectID]wire.Object - // outstanding is the number of outstanding objects. - outstanding uint32 + // pending is the set of objects that are not yet complete. + pending completeList - // r is the input stream. - r io.Reader - - // stats is the passed stats object. - stats *Stats - - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } // lookup looks up an object in decodeState or returns nil if no such object // has been previously registered. -func (ds *decodeState) lookup(id uint64) *objectState { - return ds.objectsByID[id] +func (ds *decodeState) lookup(id objectID) *objectDecodeState { + if len(ds.objectsByID) < int(id) { + return nil + } + return ds.objectsByID[id-1] +} + +// checkComplete checks for completion. +func (ds *decodeState) checkComplete(ods *objectDecodeState) bool { + // Still blocked? + if ods.blockedBy > 0 { + return false + } + + // Track stats if relevant. + if ods.callbacks != nil && ods.typ != 0 { + ds.stats.start(ods.typ) + defer ds.stats.done() + } + + // Fire all callbacks. + for _, ic := range ods.callbacks { + ic.callbackRun() + } + + // Mark completed. + cbs := ods.callbacks + ods.callbacks = nil + ds.pending.Remove(ods) + + // Recursively check others. + for _, ic := range cbs { + if other := ic.source(); other != nil && other.blockedBy == 0 { + ds.checkComplete(other) + } + } + + return true // All set. } // wait registers a dependency on an object. @@ -168,11 +208,8 @@ func (ds *decodeState) lookup(id uint64) *objectState { // As a special case, we always allow _useable_ references back to the first // decoding object because it may have fields that are already decoded. We also // allow trivial self reference, since they can be handled internally. -func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { +func (ds *decodeState) wait(waiter *objectDecodeState, id objectID, callback func()) { switch id { - case 0: - // Nil pointer; nothing to wait for. - fallthrough case waiter.id: // Trivial self reference. fallthrough @@ -184,107 +221,188 @@ func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) { return } + // Mark as blocked. + waiter.blockedBy++ + // No nil can be returned here. - waiter.waitFor(ds.lookup(id), callback) + other := ds.lookup(id) + if callback != nil { + // Add the additional user callback. + other.addCallback(userCallback(callback)) + } + + // Mark waiter as unblocked. + other.addCallback(waiter) } // waitObject notes a blocking relationship. -func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) { - if rv, ok := p.Value.(*pb.Object_RefValue); ok { +func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, callback func()) { + if rv, ok := encoded.(*wire.Ref); ok && rv.Root != 0 { // Refs can encode pointers and maps. - ds.wait(os, rv.RefValue, callback) - } else if sv, ok := p.Value.(*pb.Object_SliceValue); ok { + ds.wait(ods, objectID(rv.Root), callback) + } else if sv, ok := encoded.(*wire.Slice); ok && sv.Ref.Root != 0 { // See decodeObject; we need to wait for the array (if non-nil). - ds.wait(os, sv.SliceValue.RefValue, callback) - } else if iv, ok := p.Value.(*pb.Object_InterfaceValue); ok { + ds.wait(ods, objectID(sv.Ref.Root), callback) + } else if iv, ok := encoded.(*wire.Interface); ok { // It's an interface (wait recurisvely). - ds.waitObject(os, iv.InterfaceValue.Value, callback) + ds.waitObject(ods, iv.Value, callback) } else if callback != nil { // Nothing to wait for: execute the callback immediately. callback() } } +// walkChild returns a child object from obj, given an accessor path. This is +// the decode-side equivalent to traverse in encode.go. +// +// For the purposes of this function, a child object is either a field within a +// struct or an array element, with one such indirection per element in +// path. The returned value may be an unexported field, so it may not be +// directly assignable. See unsafePointerTo. +func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { + // See wire.Ref.Dots. The path here is specified in reverse order. + for i := len(path) - 1; i >= 0; i-- { + switch pc := path[i].(type) { + case *wire.FieldName: // Must be a pointer. + if obj.Kind() != reflect.Struct { + Failf("next component in child path is a field name, but the current object is not a struct. Path: %v, current obj: %#v", path, obj) + } + obj = obj.FieldByName(string(*pc)) + case wire.Index: // Embedded. + if obj.Kind() != reflect.Array { + Failf("next component in child path is an array index, but the current object is not an array. Path: %v, current obj: %#v", path, obj) + } + obj = obj.Index(int(pc)) + default: + panic("unreachable: switch should be exhaustive") + } + } + return obj +} + // register registers a decode with a type. // // This type is only used to instantiate a new object if it has not been -// registered previously. -func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState { - os, ok := ds.objectsByID[id] - if ok { - return os +// registered previously. This depends on the type provided if none is +// available in the object itself. +func (ds *decodeState) register(r *wire.Ref, typ reflect.Type) reflect.Value { + // Grow the objectsByID slice. + id := objectID(r.Root) + if len(ds.objectsByID) < int(id) { + ds.objectsByID = append(ds.objectsByID, make([]*objectDecodeState, int(id)-len(ds.objectsByID))...) + } + + // Does this object already exist? + ods := ds.objectsByID[id-1] + if ods != nil { + return walkChild(r.Dots, ods.obj) + } + + // Create the object. + if len(r.Dots) != 0 { + typ = ds.findType(r.Type) } + v := reflect.New(typ) + ods = &objectDecodeState{ + id: id, + obj: v.Elem(), + } + ds.objectsByID[id-1] = ods + ds.pending.PushBack(ods) - // Record in the object index. - if typ.Kind() == reflect.Map { - os = &objectState{id: id, obj: reflect.MakeMap(typ), path: ds.recoverable.copy()} - } else { - os = &objectState{id: id, obj: reflect.New(typ).Elem(), path: ds.recoverable.copy()} + // Process any deferred objects & callbacks. + if encoded, ok := ds.deferred[id]; ok { + delete(ds.deferred, id) + ds.decodeObject(ods, ods.obj, encoded) } - ds.objectsByID[id] = os - if o, ok := ds.deferred[id]; ok { - // There is a deferred object. - delete(ds.deferred, id) // Free memory. - ds.decodeObject(os, os.obj, o, "", nil) - } else { - // There is no deferred object. - ds.outstanding++ + return walkChild(r.Dots, ods.obj) +} + +// objectDecoder is for decoding structs. +type objectDecoder struct { + // ds is decodeState. + ds *decodeState + + // ods is current object being decoded. + ods *objectDecodeState + + // reconciledTypeEntry is the reconciled type information. + rte *reconciledTypeEntry + + // encoded is the encoded object state. + encoded *wire.Struct +} + +// load is helper for the public methods on Source. +func (od *objectDecoder) load(slot int, objPtr reflect.Value, wait bool, fn func()) { + // Note that we have reconciled the type and may remap the fields here + // to match what's expected by the decoder. The "slot" parameter here + // is in terms of the local type, where the fields in the encoded + // object are in terms of the wire object's type, which might be in a + // different order (but will have the same fields). + v := *od.encoded.Field(od.rte.FieldOrder[slot]) + od.ds.decodeObject(od.ods, objPtr.Elem(), v) + if wait { + // Mark this individual object a blocker. + od.ds.waitObject(od.ods, v, fn) } +} - return os +// aterLoad implements Source.AfterLoad. +func (od *objectDecoder) afterLoad(fn func()) { + // Queue the local callback; this will execute when all of the above + // data dependencies have been cleared. + od.ods.addCallback(userCallback(fn)) } // decodeStruct decodes a struct value. -func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) { - // Set the fields. - m := Map{newInternalMap(nil, ds, os)} - defer internalMapPool.Put(m.internalMap) - for _, field := range s.Fields { - m.data = append(m.data, entry{ - name: field.Name, - object: field.Value, - }) - } - - // Sort the fields for efficient searching. - // - // Technically, these should already appear in sorted order in the - // state ordering, so this cost is effectively a single scan to ensure - // that the order is correct. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - } - - // Invoke the load; this will recursively decode other objects. - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the loader. - fns.invokeLoad(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow anonymous empty structs. - return - } else { +func (ds *decodeState) decodeStruct(ods *objectDecodeState, obj reflect.Value, encoded *wire.Struct) { + if encoded.TypeID == 0 { + // Allow anonymous empty structs, but only if the encoded + // object also has no fields. + if encoded.Fields() == 0 && obj.NumField() == 0 { + return + } + // Propagate an error. - panic(fmt.Errorf("unregistered type %s", obj.Type())) + Failf("empty struct on wire %#v has field mismatch with type %q", encoded, obj.Type().Name()) + } + + // Lookup the object type. + rte := ds.types.Lookup(typeID(encoded.TypeID), obj.Type()) + ods.typ = typeID(encoded.TypeID) + + // Invoke the loader. + od := objectDecoder{ + ds: ds, + ods: ods, + rte: rte, + encoded: encoded, + } + ds.stats.start(ods.typ) + defer ds.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateLoad(Source{internal: od}) } } // decodeMap decodes a map value. -func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) { +func (ds *decodeState) decodeMap(ods *objectDecodeState, obj reflect.Value, encoded *wire.Map) { if obj.IsNil() { + // See pointerTo. obj.Set(reflect.MakeMap(obj.Type())) } - for i := 0; i < len(m.Keys); i++ { + for i := 0; i < len(encoded.Keys); i++ { // Decode the objects. kv := reflect.New(obj.Type().Key()).Elem() vv := reflect.New(obj.Type().Elem()).Elem() - ds.decodeObject(os, kv, m.Keys[i], ".(key %d)", i) - ds.decodeObject(os, vv, m.Values[i], "[%#v]", kv.Interface()) - ds.waitObject(os, m.Keys[i], nil) - ds.waitObject(os, m.Values[i], nil) + ds.decodeObject(ods, kv, encoded.Keys[i]) + ds.decodeObject(ods, vv, encoded.Values[i]) + ds.waitObject(ods, encoded.Keys[i], nil) + ds.waitObject(ods, encoded.Values[i], nil) // Set in the map. obj.SetMapIndex(kv, vv) @@ -292,271 +410,294 @@ func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) } // decodeArray decodes an array value. -func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) { - if len(a.Contents) != obj.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", obj.Len(), len(a.Contents))) +func (ds *decodeState) decodeArray(ods *objectDecodeState, obj reflect.Value, encoded *wire.Array) { + if len(encoded.Contents) != obj.Len() { + Failf("mismatching array length expect=%d, actual=%d", obj.Len(), len(encoded.Contents)) } // Decode the contents into the array. - for i := 0; i < len(a.Contents); i++ { - ds.decodeObject(os, obj.Index(i), a.Contents[i], "[%d]", i) - ds.waitObject(os, a.Contents[i], nil) + for i := 0; i < len(encoded.Contents); i++ { + ds.decodeObject(ods, obj.Index(i), encoded.Contents[i]) + ds.waitObject(ods, encoded.Contents[i], nil) } } -// decodeInterface decodes an interface value. -func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) { - // Is this a nil value? - if i.Type == "" { - return // Just leave obj alone. +// findType finds the type for the given wire.TypeSpecs. +func (ds *decodeState) findType(t wire.TypeSpec) reflect.Type { + switch x := t.(type) { + case wire.TypeID: + typ := ds.types.LookupType(typeID(x)) + rte := ds.types.Lookup(typeID(x), typ) + return rte.LocalType + case *wire.TypeSpecPointer: + return reflect.PtrTo(ds.findType(x.Type)) + case *wire.TypeSpecArray: + return reflect.ArrayOf(int(x.Count), ds.findType(x.Type)) + case *wire.TypeSpecSlice: + return reflect.SliceOf(ds.findType(x.Type)) + case *wire.TypeSpecMap: + return reflect.MapOf(ds.findType(x.Key), ds.findType(x.Value)) + default: + // Should not happen. + Failf("unknown type %#v", t) } + panic("unreachable") +} - // Get the dispatchable type. This may not be used if the given - // reference has already been resolved, but if not we need to know the - // type to create. - t, ok := registeredTypes.lookupType(i.Type) - if !ok { - panic(fmt.Errorf("no valid type for %q", i.Type)) +// decodeInterface decodes an interface value. +func (ds *decodeState) decodeInterface(ods *objectDecodeState, obj reflect.Value, encoded *wire.Interface) { + if _, ok := encoded.Type.(wire.TypeSpecNil); ok { + // Special case; the nil object. Just decode directly, which + // will read nil from the wire (if encoded correctly). + ds.decodeObject(ods, obj, encoded.Value) + return } - if obj.Kind() != reflect.Map { - // Set the obj to be the given typed value; this actually sets - // obj to be a non-zero value -- namely, it inserts type - // information. There's no need to do this for maps. - obj.Set(reflect.Zero(t)) + // We now need to resolve the actual type. + typ := ds.findType(encoded.Type) + + // We need to imbue type information here, then we can proceed to + // decode normally. In order to avoid issues with setting value-types, + // we create a new non-interface version of this object. We will then + // set the interface object to be equal to whatever we decode. + origObj := obj + obj = reflect.New(typ).Elem() + defer origObj.Set(obj) + + // With the object now having sufficient type information to actually + // have Set called on it, we can proceed to decode the value. + ds.decodeObject(ods, obj, encoded.Value) +} + +// isFloatEq determines if x and y represent the same value. +func isFloatEq(x float64, y float64) bool { + switch { + case math.IsNaN(x): + return math.IsNaN(y) + case math.IsInf(x, 1): + return math.IsInf(y, 1) + case math.IsInf(x, -1): + return math.IsInf(y, -1) + default: + return x == y } +} - // Decode the dereferenced element; there is no need to wait here, as - // the interface object shares the current object state. - ds.decodeObject(os, obj, i.Value, ".(%s)", i.Type) +// isComplexEq determines if x and y represent the same value. +func isComplexEq(x complex128, y complex128) bool { + return isFloatEq(real(x), real(y)) && isFloatEq(imag(x), imag(y)) } // decodeObject decodes a object value. -func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) { - ds.push(false, format, param) - ds.stats.Add(obj) - ds.stats.Start(obj) - - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - obj.SetBool(x.BoolValue) - case *pb.Object_StringValue: - obj.SetString(string(x.StringValue)) - case *pb.Object_Int64Value: - obj.SetInt(x.Int64Value) - if obj.Int() != x.Int64Value { - panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type())) +func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, encoded wire.Object) { + switch x := encoded.(type) { + case wire.Nil: // Fast path: first. + // We leave obj alone here. That's because if obj represents an + // interface, it may have been imbued with type information in + // decodeInterface, and we don't want to destroy that. + case *wire.Ref: + // Nil pointers may be encoded in a "forceValue" context. For + // those we just leave it alone as the value will already be + // correct (nil). + if id := objectID(x.Root); id == 0 { + return } - case *pb.Object_Uint64Value: - obj.SetUint(x.Uint64Value) - if obj.Uint() != x.Uint64Value { - panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_DoubleValue: - obj.SetFloat(x.DoubleValue) - if obj.Float() != x.DoubleValue { - panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type())) - } - case *pb.Object_RefValue: - // Resolve the pointer itself, even though the object may not - // be decoded yet. You need to use wait() in order to ensure - // that is the case. See wait above, and Map.Barrier. - if id := x.RefValue; id != 0 { - // Decoding the interface should have imparted type - // information, so from this point it's safe to resolve - // and use this dynamic information for actually - // creating the object in register. - // - // (For non-interfaces this is a no-op). - dyntyp := reflect.TypeOf(obj.Interface()) - if dyntyp.Kind() == reflect.Map { - // Remove the map object count here to avoid - // double counting, as this object will be - // counted again when it gets processed later. - // We do not add a reference count as the - // reference is artificial. - ds.stats.Remove(obj) - obj.Set(ds.register(id, dyntyp).obj) - } else if dyntyp.Kind() == reflect.Ptr { - ds.push(true /* dereference */, "", nil) - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) - ds.pop() - } else { - obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr()) + + // Note that if this is a map type, we go through a level of + // indirection to allow for map aliasing. + if obj.Kind() == reflect.Map { + v := ds.register(x, obj.Type()) + if v.IsNil() { + // Note that we don't want to clobber the map + // if has already been decoded by decodeMap. We + // just make it so that we have a consistent + // reference when that eventually does happen. + v.Set(reflect.MakeMap(v.Type())) } - } else { - // We leave obj alone here. That's because if obj - // represents an interface, it may have been embued - // with type information in decodeInterface, and we - // don't want to destroy that information. + obj.Set(v) + return } - case *pb.Object_SliceValue: - // It's okay to slice the array here, since the contents will - // still be provided later on. These semantics are a bit - // strange but they are handled in the Map.Barrier properly. - // - // The special semantics of zero ref apply here too. - if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 { - v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem()) - obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity))) + + // Normal assignment: authoritative only if no dots. + v := ds.register(x, obj.Type().Elem()) + if v.IsValid() { + obj.Set(unsafePointerTo(v)) } - case *pb.Object_ArrayValue: - ds.decodeArray(os, obj, x.ArrayValue) - case *pb.Object_StructValue: - ds.decodeStruct(os, obj, x.StructValue) - case *pb.Object_MapValue: - ds.decodeMap(os, obj, x.MapValue) - case *pb.Object_InterfaceValue: - ds.decodeInterface(os, obj, x.InterfaceValue) - case *pb.Object_ByteArrayValue: - copyArray(obj, reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Uint16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]uint16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Bool: + obj.SetBool(bool(x)) + case wire.Int: + obj.SetInt(int64(x)) + if obj.Int() != int64(x) { + Failf("signed integer truncated from %v to %v", int64(x), obj.Int()) } - for i := range s { - t[i] = uint16(s[i]) + case wire.Uint: + obj.SetUint(uint64(x)) + if obj.Uint() != uint64(x) { + Failf("unsigned integer truncated from %v to %v", uint64(x), obj.Uint()) } - case *pb.Object_Uint32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := x.Int16ArrayValue.Values - t := obj.Slice(0, obj.Len()).Interface().([]int16) - if len(t) != len(s) { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s))) + case wire.Float32: + obj.SetFloat(float64(x)) + case wire.Float64: + obj.SetFloat(float64(x)) + if !isFloatEq(obj.Float(), float64(x)) { + Failf("floating point number truncated from %v to %v", float64(x), obj.Float()) } - for i := range s { - t[i] = int16(s[i]) + case *wire.Complex64: + obj.SetComplex(complex128(*x)) + case *wire.Complex128: + obj.SetComplex(complex128(*x)) + if !isComplexEq(obj.Complex(), complex128(*x)) { + Failf("complex number truncated from %v to %v", complex128(*x), obj.Complex()) } - case *pb.Object_Int32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values)) + case *wire.String: + obj.SetString(string(*x)) + case *wire.Slice: + // See *wire.Ref above; same applies. + if id := objectID(x.Ref.Root); id == 0 { + return + } + // Note that it's fine to slice the array here and assume that + // contents will still be filled in later on. + typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. + v := ds.register(&x.Ref, typ) + obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + case *wire.Array: + ds.decodeArray(ods, obj, x) + case *wire.Struct: + ds.decodeStruct(ods, obj, x) + case *wire.Map: + ds.decodeMap(ods, obj, x) + case *wire.Interface: + ds.decodeInterface(ods, obj, x) default: // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type())) - } - - ds.stats.Done() - ds.pop() -} - -func copyArray(dest reflect.Value, src reflect.Value) { - if dest.Len() != src.Len() { - panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", dest.Len(), src.Len())) + Failf("unknown object %#v for %q", encoded, obj.Type().Name()) } - reflect.Copy(dest, castSlice(src, dest.Type().Elem())) } -// Deserialize deserializes the object state. +// Load deserializes the object graph rooted at obj. // // This function may panic and should be run in safely(). -func (ds *decodeState) Deserialize(obj reflect.Value) { - ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()} - ds.outstanding = 1 // The root object. +func (ds *decodeState) Load(obj reflect.Value) { + ds.stats.init() + defer ds.stats.fini(func(id typeID) string { + return ds.types.LookupName(id) + }) + + // Create the root object. + ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + id: 1, + obj: obj, + }) + + // Read the number of objects. + lastID, object, err := ReadHeader(ds.r) + if err != nil { + Failf("header error: %w", err) + } + if !object { + Failf("object missing") + } + + // Decode all objects. + var ( + encoded wire.Object + ods *objectDecodeState + id = objectID(1) + tid = typeID(1) + ) + if err := safely(func() { + // Decode all objects in the stream. + // + // Note that the structure of this decoding loop should match + // the raw decoding loop in printer.go. + for id <= objectID(lastID) { + // Unmarshal the object. + encoded = wire.Load(ds.r) + + // Is this a type object? Handle inline. + if wt, ok := encoded.(*wire.Type); ok { + ds.types.Register(wt) + tid++ + encoded = nil + continue + } - // Decode all objects in the stream. - // - // See above, we never process objects while we have no outstanding - // interests (other than the very first object). - for id := uint64(1); ds.outstanding > 0; id++ { - os := ds.lookup(id) - ds.stats.Start(os.obj) - - o, err := ds.readObject() - if err != nil { - panic(err) - } + // Actually resolve the object. + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } - if os != nil { - // Decode the object. - ds.from = &os.path - ds.decodeObject(os, os.obj, o, "", nil) - ds.outstanding-- + // For error handling. + ods = nil + encoded = nil + id++ + } + }); err != nil { + // Include as much information as we can, taking into account + // the possible state transitions above. + if ods != nil { + Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) + } else if encoded != nil { + Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) } else { - // If an object hasn't had interest registered - // previously, we deferred decoding until interest is - // registered. - ds.deferred[id] = o + Failf("general decoding error: %w", err) } - - ds.stats.Done() - } - - // Check the zero-length header at the end. - length, object, err := ReadHeader(ds.r) - if err != nil { - panic(err) - } - if length != 0 { - panic(fmt.Sprintf("expected zero-length terminal, got %d", length)) - } - if object { - panic("expected non-object terminal") } // Check if we have any deferred objects. - if count := len(ds.deferred); count > 0 { - // Shoud not happen, not propagated as an error. - panic(fmt.Sprintf("still have %d deferred objects", count)) - } - - // Scan and fire all callbacks. - for _, os := range ds.objectsByID { - os.checkComplete(ds.stats) + for id, encoded := range ds.deferred { + // Shoud never happen, the graph was bogus. + Failf("still have deferred objects: one is ID %d, %#v", id, encoded) } - // Check if we have any remaining dependency cycles. - for _, os := range ds.objectsByID { - if !os.complete() { - // This must be the result of a dependency cycle. - cycle := os.findCycle() - var buf bytes.Buffer - buf.WriteString("dependency cycle: {") - for i, cycleOS := range cycle { - if i > 0 { - buf.WriteString(" => ") + // Scan and fire all callbacks. We iterate over the list of incomplete + // objects until all have been finished. We stop iterating if no + // objects become complete (there is a dependency cycle). + // + // Note that we iterate backwards here, because there will be a strong + // tendendcy for blocking relationships to go from earlier objects to + // later (deeper) objects in the graph. This will reduce the number of + // iterations required to finish all objects. + if err := safely(func() { + for ds.pending.Back() != nil { + thisCycle := false + for ods = ds.pending.Back(); ods != nil; { + if ds.checkComplete(ods) { + thisCycle = true + break } - buf.WriteString(fmt.Sprintf("%s", cycleOS.obj.Type())) + ods = ods.Prev() + } + if !thisCycle { + break } - buf.WriteString("}") - // Panic as an error; propagate to the caller. - panic(errors.New(string(buf.Bytes()))) } - } -} - -type byteReader struct { - io.Reader -} - -// ReadByte implements io.ByteReader. -func (br byteReader) ReadByte() (byte, error) { - var b [1]byte - n, err := br.Reader.Read(b[:]) - if n > 0 { - return b[0], nil - } else if err != nil { - return 0, err - } else { - return 0, io.ErrUnexpectedEOF + }); err != nil { + Failf("error executing callbacks for %#v: %w", ods.obj.Interface(), err) + } + + // Check if we have any remaining dependency cycles. If there are any + // objects left in the pending list, then it must be due to a cycle. + if ods := ds.pending.Front(); ods != nil { + // This must be the result of a dependency cycle. + cycle := ods.findCycle() + var buf bytes.Buffer + buf.WriteString("dependency cycle: {") + for i, cycleOS := range cycle { + if i > 0 { + buf.WriteString(" => ") + } + fmt.Fprintf(&buf, "%q", cycleOS.obj.Type()) + } + buf.WriteString("}") + Failf("incomplete graph: %s", string(buf.Bytes())) } } @@ -565,45 +706,20 @@ func (br byteReader) ReadByte() (byte, error) { // Each object written to the statefile is prefixed with a header. See // WriteHeader for more information; these functions are exported to allow // non-state writes to the file to play nice with debugging tools. -func ReadHeader(r io.Reader) (length uint64, object bool, err error) { +func ReadHeader(r wire.Reader) (length uint64, object bool, err error) { // Read the header. - length, err = binary.ReadUvarint(byteReader{r}) + err = safely(func() { + length = wire.LoadUint(r) + }) if err != nil { - return + // On the header, pass raw I/O errors. + if sErr, ok := err.(*ErrState); ok { + return 0, false, sErr.Unwrap() + } } // Decode whether the object is valid. - object = length&0x1 != 0 - length = length >> 1 + object = length&objectFlag != 0 + length &^= objectFlag return } - -// readObject reads an object from the stream. -func (ds *decodeState) readObject() (*pb.Object, error) { - // Read the header. - length, object, err := ReadHeader(ds.r) - if err != nil { - return nil, err - } - if !object { - return nil, fmt.Errorf("invalid object header") - } - - // Read the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := ds.r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return nil, err - } - } - - // Unmarshal. - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return nil, err - } - - return obj, nil -} diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go new file mode 100644 index 000000000..d048f61a1 --- /dev/null +++ b/pkg/state/decode_unsafe.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "unsafe" +) + +// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on +// values representing unexported fields. This bypasses visibility, but not +// type safety. +func unsafePointerTo(obj reflect.Value) reflect.Value { + return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index c5118d3a9..92fcad4e9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -15,437 +15,797 @@ package state import ( - "container/list" "context" - "encoding/binary" - "fmt" - "io" "reflect" - "sort" - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) -// queuedObject is an object queued for encoding. -type queuedObject struct { - id uint64 - obj reflect.Value - path recoverable +// objectEncodeState the type and identity of an object occupying a memory +// address range. This is the value type for addrSet, and the intrusive entry +// for the pending and deferred lists. +type objectEncodeState struct { + // id is the assigned ID for this object. + id objectID + + // obj is the object value. Note that this may be replaced if we + // encounter an object that contains this object. When this happens (in + // resolve), we will update existing references approprately, below, + // and defer a re-encoding of the object. + obj reflect.Value + + // encoded is the encoded value of this object. Note that this may not + // be up to date if this object is still in the deferred list. + encoded wire.Object + + // how indicates whether this object should be encoded as a value. This + // is used only for deferred encoding. + how encodeStrategy + + // refs are the list of reference objects used by other objects + // referring to this object. When the object is updated, these + // references may be updated directly and automatically. + refs []*wire.Ref + + pendingEntry + deferredEntry } // encodeState is state used for encoding. // -// The encoding process is a breadth-first traversal of the object graph. The -// inherent races and dependencies are much simpler than the decode case. +// The encoding process constructs a representation of the in-memory graph of +// objects before a single object is serialized. This is done to ensure that +// all references can be fully disambiguated. See resolve for more details. type encodeState struct { // ctx is the encode context. ctx context.Context - // lastID is the last object ID. - // - // See idsByObject for context. Because of the special zero encoding - // used for reference values, the first ID must be 1. - lastID uint64 + // w is the output stream. + w wire.Writer - // idsByObject is a set of objects, indexed via: - // - // reflect.ValueOf(x).UnsafeAddr - // - // This provides IDs for objects. - idsByObject map[uintptr]uint64 + // types is the type database. + types typeEncodeDatabase + + // lastID is the last allocated object ID. + lastID objectID - // values stores values that span the addresses. + // values tracks the address ranges occupied by objects, along with the + // types of these objects. This is used to locate pointer targets, + // including pointers to fields within another type. // - // addrSet is a a generated type which efficiently stores ranges of - // addresses. When encoding pointers, these ranges are filled in and - // used to check for overlapping or conflicting pointers. This would - // indicate a pointer to an field, or a non-type safe value, neither of - // which are currently decodable. + // Multiple objects may overlap in memory iff the larger object fully + // contains the smaller one, and the type of the smaller object matches + // a field or array element's type at the appropriate offset. An + // arbitrary number of objects may be nested in this manner. // - // See the usage of values below for more context. + // Note that this does not track zero-sized objects, those are tracked + // by zeroValues below. values addrSet - // w is the output stream. - w io.Writer + // zeroValues tracks zero-sized objects. + zeroValues map[reflect.Type]*objectEncodeState - // pending is the list of objects to be serialized. - // - // This is a set of queuedObjects. - pending list.List + // deferred is the list of objects to be encoded. + deferred deferredList - // done is the a list of finished objects. - // - // This is kept to prevent garbage collection and address reuse. - done list.List + // pendingTypes is the list of types to be serialized. Serialization + // will occur when all objects have been encoded, but before pending is + // serialized. + pendingTypes []wire.Type - // stats is the passed stats object. - stats *Stats + // pending is the list of objects to be serialized. Serialization does + // not actually occur until the full object graph is computed. + pending pendingList - // recoverable is the panic recover facility. - recoverable + // stats tracks time data. + stats Stats } -// register looks up an ID, registering if necessary. +// isSameSizeParent returns true if child is a field value or element within +// parent. Only a struct or array can have a child value. +// +// isSameSizeParent deals with objects like this: +// +// struct child { +// // fields.. +// } // -// If the object was not previously registered, it is enqueued to be serialized. -// See the documentation for idsByObject for more information. -func (es *encodeState) register(obj reflect.Value) uint64 { - // It is not legal to call register for any non-pointer objects (see - // below), so we panic with a recoverable error if this is a mismatch. - if obj.Kind() != reflect.Ptr && obj.Kind() != reflect.Map { - panic(fmt.Errorf("non-pointer %#v registered", obj.Interface())) +// struct parent { +// c child +// } +// +// var p parent +// record(&p.c) +// +// Here, &p and &p.c occupy the exact same address range. +// +// Or like this: +// +// struct child { +// // fields +// } +// +// var arr [1]parent +// record(&arr[0]) +// +// Similarly, &arr[0] and &arr[0].c have the exact same address range. +// +// Precondition: parent and child must occupy the same memory. +func isSameSizeParent(parent reflect.Value, childType reflect.Type) bool { + switch parent.Kind() { + case reflect.Struct: + for i := 0; i < parent.NumField(); i++ { + field := parent.Field(i) + if field.Type() == childType { + return true + } + // Recurse through any intermediate types. + if isSameSizeParent(field, childType) { + return true + } + // Does it make sense to keep going if the first field + // doesn't match? Yes, because there might be an + // arbitrary number of zero-sized fields before we get + // a match, and childType itself can be zero-sized. + } + return false + case reflect.Array: + // The only case where an array with more than one elements can + // return true is if childType is zero-sized. In such cases, + // it's ambiguous which element contains the match since a + // zero-sized child object fully fits in any of the zero-sized + // elements in an array... However since all elements are of + // the same type, we only need to check one element. + // + // For non-zero-sized childTypes, parent.Len() must be 1, but a + // combination of the precondition and an implicit comparison + // between the array element size and childType ensures this. + return parent.Len() > 0 && isSameSizeParent(parent.Index(0), childType) + default: + return false } +} - addr := obj.Pointer() - if obj.Kind() == reflect.Ptr && obj.Elem().Type().Size() == 0 { - // For zero-sized objects, we always provide a unique ID. - // That's because the runtime internally multiplexes pointers - // to the same address. We can't be certain what the intent is - // with pointers to zero-sized objects, so we just give them - // all unique identities. - } else if id, ok := es.idsByObject[addr]; ok { - // Already registered. - return id - } - - // Ensure that the first ID given out is one. See note on lastID. The - // ID zero is used to indicate nil values. +// nextID returns the next valid ID. +func (es *encodeState) nextID() objectID { es.lastID++ - id := es.lastID - es.idsByObject[addr] = id - if obj.Kind() == reflect.Ptr { - // Dereference and treat as a pointer. - es.pending.PushBack(queuedObject{id: id, obj: obj.Elem(), path: es.recoverable.copy()}) - - // Register this object at all addresses. - typ := obj.Elem().Type() - if size := typ.Size(); size > 0 { - r := addrRange{addr, addr + size} - if !es.values.IsEmptyRange(r) { - old := es.values.LowerBoundSegment(addr).Value().Interface().(recoverable) - panic(fmt.Errorf("overlapping objects: [new object] %#v [existing object path] %s", obj.Interface(), old.path())) + return objectID(es.lastID) +} + +// dummyAddr points to the dummy zero-sized address. +var dummyAddr = reflect.ValueOf(new(struct{})).Pointer() + +// resolve records the address range occupied by an object. +func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { + addr := obj.Pointer() + + // Is this a map pointer? Just record the single address. It is not + // possible to take any pointers into the map internals. + if obj.Kind() == reflect.Map { + if addr == 0 { + // Just leave the nil reference alone. This is fine, we + // may need to encode as a reference in this way. We + // return nil for our objectEncodeState so that anyone + // depending on this value knows there's nothing there. + return + } + if seg, _ := es.values.Find(addr); seg.Ok() { + // Ensure the map types match. + existing := seg.Value() + if existing.obj.Type() != obj.Type() { + Failf("overlapping map objects at 0x%x: [new object] %#v [existing object type] %s", addr, obj, existing.obj) } - es.values.Add(r, reflect.ValueOf(es.recoverable.copy())) + + // No sense recording refs, maps may not be replaced by + // covering objects, they are maximal. + ref.Root = wire.Uint(existing.id) + return } + + // Record the map. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + how: encodeMapAsValue, + } + es.values.Add(addrRange{addr, addr + 1}, oes) + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + + // See above: no ref recording. + ref.Root = wire.Uint(oes.id) + return + } + + // If not a map, then the object must be a pointer. + if obj.Kind() != reflect.Ptr { + Failf("attempt to record non-map and non-pointer object %#v", obj) + } + + obj = obj.Elem() // Value from here. + + // Is this a zero-sized type? + typ := obj.Type() + size := typ.Size() + if size == 0 { + if addr == dummyAddr { + // Zero-sized objects point to a dummy byte within the + // runtime. There's no sense recording this in the + // address map. We add this to the dedicated + // zeroValues. + // + // Note that zero-sized objects must be *true* + // zero-sized objects. They cannot be part of some + // larger object. In that case, they are assigned a + // 1-byte address at the end of the object. + oes, ok := es.zeroValues[typ] + if !ok { + oes = &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + es.zeroValues[typ] = oes + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + } + + // There's also no sense tracking back references. We + // know that this is a true zero-sized object, and not + // part of a larger container, so it will not change. + ref.Root = wire.Uint(oes.id) + return + } + size = 1 // See above. + } + + // Calculate the container. + end := addr + size + r := addrRange{addr, end} + if seg, _ := es.values.Find(addr); seg.Ok() { + existing := seg.Value() + switch { + case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): + // The object is a perfect match. Happy path. Avoid the + // traversal and just return directly. We don't need to + // encode the type information or any dots here. + ref.Root = wire.Uint(existing.id) + existing.refs = append(existing.refs, ref) + return + + case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): + // The previously registered object is larger than + // this, no need to update. But we expect some + // traversal below. + + case seg.Start() == addr && seg.End() == end: + if !isSameSizeParent(obj, existing.obj.Type()) { + break // Needs traversal. + } + fallthrough // Needs update. + + case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): + // Update the object and redo the encoding. + old := existing.obj + existing.obj = obj + es.deferred.Remove(existing) + es.deferred.PushBack(existing) + + // The previously registered object is superseded by + // this new object. We are guaranteed to not have any + // mergeable neighbours in this segment set. + if !raceEnabled { + seg.SetRangeUnchecked(r) + } else { + // Add extra paranoid. This will be statically + // removed at compile time unless a race build. + es.values.Remove(seg) + es.values.Add(r, existing) + seg = es.values.LowerBoundSegment(addr) + } + + // Compute the traversal required & update references. + dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) + wt := es.findType(obj.Type()) + for _, ref := range existing.refs { + ref.Dots = append(ref.Dots, dots...) + ref.Type = wt + } + default: + // There is a non-sensical overlap. + Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + } + + // Compute the new reference, record and return it. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) + ref.Type = es.findType(obj.Type()) + existing.refs = append(existing.refs, ref) + return + } + + // The only remaining case is a pointer value that doesn't overlap with + // any registered addresses. Create a new entry for it, and start + // tracking the first reference we just created. + oes := &objectEncodeState{ + id: es.nextID(), + obj: obj, + } + if !raceEnabled { + es.values.AddWithoutMerging(r, oes) } else { - // Push back the map itself; when maps are encoded from the - // top-level, forceMap will be equal to true. - es.pending.PushBack(queuedObject{id: id, obj: obj, path: es.recoverable.copy()}) + // Merges should never happen. This is just enabled extra + // sanity checks because the Merge function below will panic. + es.values.Add(r, oes) + } + es.pending.PushBack(oes) + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) +} + +// traverse searches for a target object within a root object, where the target +// object is a struct field or array element within root, with potentially +// multiple intervening types. traverse returns the set of field or element +// traversals required to reach the target. +// +// Note that for efficiency, traverse returns the dots in the reverse order. +// That is, the first traversal required will be the last element of the list. +// +// Precondition: The target object must lie completely within the range defined +// by [rootAddr, rootAddr + sizeof(rootType)]. +func traverse(rootType, targetType reflect.Type, rootAddr, targetAddr uintptr) []wire.Dot { + // Recursion base case: the types actually match. + if targetType == rootType && targetAddr == rootAddr { + return nil } - return id + switch rootType.Kind() { + case reflect.Struct: + offset := targetAddr - rootAddr + for i := rootType.NumField(); i > 0; i-- { + field := rootType.Field(i - 1) + // The first field from the end with an offset that is + // smaller than or equal to our address offset is where + // the target is located. Traverse from there. + if field.Offset <= offset { + dots := traverse(field.Type, targetType, rootAddr+field.Offset, targetAddr) + fieldName := wire.FieldName(field.Name) + return append(dots, &fieldName) + } + } + // Should never happen; the target should be reachable. + Failf("no field in root type %v contains target type %v", rootType, targetType) + + case reflect.Array: + // Since arrays have homogenous types, all elements have the + // same size and we can compute where the target lives. This + // does not matter for the purpose of typing, but matters for + // the purpose of computing the address of the given index. + elemSize := int(rootType.Elem().Size()) + n := int(targetAddr-rootAddr) / elemSize // Relies on integer division rounding down. + if rootType.Len() < n { + Failf("traversal target of type %v @%x is beyond the end of the array type %v @%x with %v elements", + targetType, targetAddr, rootType, rootAddr, rootType.Len()) + } + dots := traverse(rootType.Elem(), targetType, rootAddr+uintptr(n*elemSize), targetAddr) + return append(dots, wire.Index(n)) + + default: + // For any other type, there's no possibility of aliasing so if + // the types didn't match earlier then we have an addresss + // collision which shouldn't be possible at this point. + Failf("traverse failed for root type %v and target type %v", rootType, targetType) + } + panic("unreachable") } // encodeMap encodes a map. -func (es *encodeState) encodeMap(obj reflect.Value) *pb.Map { - var ( - keys []*pb.Object - values []*pb.Object - ) +func (es *encodeState) encodeMap(obj reflect.Value, dest *wire.Object) { + if obj.IsNil() { + // Because there is a difference between a nil map and an empty + // map, we need to not decode in the case of a truly nil map. + *dest = wire.Nil{} + return + } + l := obj.Len() + m := &wire.Map{ + Keys: make([]wire.Object, l), + Values: make([]wire.Object, l), + } + *dest = m for i, k := range obj.MapKeys() { v := obj.MapIndex(k) - kp := es.encodeObject(k, false, ".(key %d)", i) - vp := es.encodeObject(v, false, "[%#v]", k.Interface()) - keys = append(keys, kp) - values = append(values, vp) + // Map keys must be encoded using the full value because the + // type will be omitted after the first key. + es.encodeObject(k, encodeAsValue, &m.Keys[i]) + es.encodeObject(v, encodeAsValue, &m.Values[i]) } - return &pb.Map{Keys: keys, Values: values} +} + +// objectEncoder is for encoding structs. +type objectEncoder struct { + // es is encodeState. + es *encodeState + + // encoded is the encoded struct. + encoded *wire.Struct +} + +// save is called by the public methods on Sink. +func (oe *objectEncoder) save(slot int, obj reflect.Value) { + fieldValue := oe.encoded.Field(slot) + oe.es.encodeObject(obj, encodeDefault, fieldValue) } // encodeStruct encodes a composite object. -func (es *encodeState) encodeStruct(obj reflect.Value) *pb.Struct { - // Invoke the save. - m := Map{newInternalMap(es, nil, nil)} - defer internalMapPool.Put(m.internalMap) +func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + // Ensure that the obj is addressable. There are two cases when it is + // not. First, is when this is dispatched via SaveValue. Second, when + // this is a map key as a struct. Either way, we need to make a copy to + // obtain an addressable value. if !obj.CanAddr() { - // Force it to a * type of the above; this involves a copy. localObj := reflect.New(obj.Type()) localObj.Elem().Set(obj) obj = localObj.Elem() } - fns, ok := registeredTypes.lookupFns(obj.Addr().Type()) - if ok { - // Invoke the provided saver. - fns.invokeSave(obj.Addr(), m) - } else if obj.NumField() == 0 { - // Allow unregistered anonymous, empty structs. - return &pb.Struct{} - } else { - // Propagate an error. - panic(fmt.Errorf("unregistered type %T", obj.Interface())) - } - - // Sort the underlying slice, and check for duplicates. This is done - // once instead of on each add, because performing this sort once is - // far more efficient. - if len(m.data) > 1 { - sort.Slice(m.data, func(i, j int) bool { - return m.data[i].name < m.data[j].name - }) - for i := range m.data { - if i > 0 && m.data[i-1].name == m.data[i].name { - panic(fmt.Errorf("duplicate name %s", m.data[i].name)) - } + + // Prepare the value. + s := &wire.Struct{} + *dest = s + + // Look the type up in the database. + te, ok := es.types.Lookup(obj.Type()) + if te == nil { + if obj.NumField() == 0 { + // Allow unregistered anonymous, empty structs. This + // will just return success without ever invoking the + // passed function. This uses the immutable EmptyStruct + // variable to prevent an allocation in this case. + // + // Note that this mechanism does *not* work for + // interfaces in general. So you can't dispatch + // non-registered empty structs via interfaces because + // then they can't be restored. + s.Alloc(0) + return } + // We need a SaverLoader for struct types. + Failf("struct %T does not implement SaverLoader", obj.Interface()) } - - // Encode the resulting fields. - fields := make([]*pb.Field, 0, len(m.data)) - for _, e := range m.data { - fields = append(fields, &pb.Field{ - Name: e.name, - Value: e.object, - }) + if !ok { + // Queue the type to be serialized. + es.pendingTypes = append(es.pendingTypes, te.Type) } - // Return the encoded object. - return &pb.Struct{Fields: fields} + // Invoke the provided saver. + s.TypeID = wire.TypeID(te.ID) + s.Alloc(len(te.Fields)) + oe := objectEncoder{ + es: es, + encoded: s, + } + es.stats.start(te.ID) + defer es.stats.done() + if sl, ok := obj.Addr().Interface().(SaverLoader); ok { + // Note: may be a registered empty struct which does not + // implement the saver/loader interfaces. + sl.StateSave(Sink{internal: oe}) + } } // encodeArray encodes an array. -func (es *encodeState) encodeArray(obj reflect.Value) *pb.Array { - var ( - contents []*pb.Object - ) - for i := 0; i < obj.Len(); i++ { - entry := es.encodeObject(obj.Index(i), false, "[%d]", i) - contents = append(contents, entry) - } - return &pb.Array{Contents: contents} +func (es *encodeState) encodeArray(obj reflect.Value, dest *wire.Object) { + l := obj.Len() + a := &wire.Array{ + Contents: make([]wire.Object, l), + } + *dest = a + for i := 0; i < l; i++ { + // We need to encode the full value because arrays are encoded + // using the type information from only the first element. + es.encodeObject(obj.Index(i), encodeAsValue, &a.Contents[i]) + } +} + +// findType recursively finds type information. +func (es *encodeState) findType(typ reflect.Type) wire.TypeSpec { + // First: check if this is a proper type. It's possible for pointers, + // slices, arrays, maps, etc to all have some different type. + te, ok := es.types.Lookup(typ) + if te != nil { + if !ok { + // See encodeStruct. + es.pendingTypes = append(es.pendingTypes, te.Type) + } + return wire.TypeID(te.ID) + } + + switch typ.Kind() { + case reflect.Ptr: + return &wire.TypeSpecPointer{ + Type: es.findType(typ.Elem()), + } + case reflect.Slice: + return &wire.TypeSpecSlice{ + Type: es.findType(typ.Elem()), + } + case reflect.Array: + return &wire.TypeSpecArray{ + Count: wire.Uint(typ.Len()), + Type: es.findType(typ.Elem()), + } + case reflect.Map: + return &wire.TypeSpecMap{ + Key: es.findType(typ.Key()), + Value: es.findType(typ.Elem()), + } + default: + // After potentially chasing many pointers, the + // ultimate type of the object is not known. + Failf("type %q is not known", typ) + } + panic("unreachable") } // encodeInterface encodes an interface. -// -// Precondition: the value is not nil. -func (es *encodeState) encodeInterface(obj reflect.Value) *pb.Interface { - // Check for the nil interface. - obj = reflect.ValueOf(obj.Interface()) +func (es *encodeState) encodeInterface(obj reflect.Value, dest *wire.Object) { + // Dereference the object. + obj = obj.Elem() if !obj.IsValid() { - return &pb.Interface{ - Type: "", // left alone in decode. - Value: &pb.Object{Value: &pb.Object_RefValue{0}}, + // Special case: the nil object. + *dest = &wire.Interface{ + Type: wire.TypeSpecNil{}, + Value: wire.Nil{}, } + return } - // We have an interface value here. How do we save that? We - // resolve the underlying type and save it as a dispatchable. - typName, ok := registeredTypes.lookupName(obj.Type()) - if !ok { - panic(fmt.Errorf("type %s is not registered", obj.Type())) + + // Encode underlying object. + i := &wire.Interface{ + Type: es.findType(obj.Type()), } + *dest = i + es.encodeObject(obj, encodeAsValue, &i.Value) +} - // Encode the object again. - return &pb.Interface{ - Type: typName, - Value: es.encodeObject(obj, false, ".(%s)", typName), +// isPrimitive returns true if this is a primitive object, or a composite +// object composed entirely of primitives. +func isPrimitiveZero(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Ptr: + // Pointers are always treated as primitive types because we + // won't encode directly from here. Returning true here won't + // prevent the object from being encoded correctly. + return true + case reflect.Bool: + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return true + case reflect.Float32, reflect.Float64: + return true + case reflect.Complex64, reflect.Complex128: + return true + case reflect.String: + return true + case reflect.Slice: + // The slice itself a primitive, but not necessarily the array + // that points to. This is similar to a pointer. + return true + case reflect.Array: + // We cannot treat an array as a primitive, because it may be + // composed of structures or other things with side-effects. + return isPrimitiveZero(typ.Elem()) + case reflect.Interface: + // Since we now that this type is the zero type, the interface + // value must be zero. Therefore this is primitive. + return true + case reflect.Struct: + return false + case reflect.Map: + // The isPrimitiveZero function is called only on zero-types to + // see if it's safe to serialize. Since a zero map has no + // elements, it is safe to treat as a primitive. + return true + default: + Failf("unknown type %q", typ.Name()) } + panic("unreachable") } -// encodeObject encodes an object. -// -// If mapAsValue is true, then a map will be encoded directly. -func (es *encodeState) encodeObject(obj reflect.Value, mapAsValue bool, format string, param interface{}) (object *pb.Object) { - es.push(false, format, param) - es.stats.Add(obj) - es.stats.Start(obj) +// encodeStrategy is the strategy used for encodeObject. +type encodeStrategy int +const ( + // encodeDefault means types are encoded normally as references. + encodeDefault encodeStrategy = iota + + // encodeAsValue means that types will never take short-circuited and + // will always be encoded as a normal value. + encodeAsValue + + // encodeMapAsValue means that even maps will be fully encoded. + encodeMapAsValue +) + +// encodeObject encodes an object. +func (es *encodeState) encodeObject(obj reflect.Value, how encodeStrategy, dest *wire.Object) { + if how == encodeDefault && isPrimitiveZero(obj.Type()) && obj.IsZero() { + *dest = wire.Nil{} + return + } switch obj.Kind() { + case reflect.Ptr: // Fast path: first. + r := new(wire.Ref) + *dest = r + if obj.IsNil() { + // May be in an array or elsewhere such that a value is + // required. So we encode as a reference to the zero + // object, which does not exist. Note that this has to + // be handled correctly in the decode path as well. + return + } + es.resolve(obj, r) case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolValue{obj.Bool()}} + *dest = wire.Bool(obj.Bool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64Value{obj.Int()}} + *dest = wire.Int(obj.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_Uint64Value{obj.Uint()}} - case reflect.Float32, reflect.Float64: - object = &pb.Object{Value: &pb.Object_DoubleValue{obj.Float()}} + *dest = wire.Uint(obj.Uint()) + case reflect.Float32: + *dest = wire.Float32(obj.Float()) + case reflect.Float64: + *dest = wire.Float64(obj.Float()) + case reflect.Complex64: + c := wire.Complex64(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.Complex128: + c := wire.Complex128(obj.Complex()) + *dest = &c // Needs alloc. + case reflect.String: + s := wire.String(obj.String()) + *dest = &s // Needs alloc. case reflect.Array: - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - object = &pb.Object{Value: &pb.Object_ByteArrayValue{pbSlice(obj).Interface().([]byte)}} - case reflect.Uint16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]uint16) - t := make([]uint32, len(s)) - for i := range s { - t[i] = uint32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Uint16ArrayValue{&pb.Uint16S{Values: t}}} - case reflect.Uint32: - object = &pb.Object{Value: &pb.Object_Uint32ArrayValue{&pb.Uint32S{Values: pbSlice(obj).Interface().([]uint32)}}} - case reflect.Uint64: - object = &pb.Object{Value: &pb.Object_Uint64ArrayValue{&pb.Uint64S{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Uintptr: - object = &pb.Object{Value: &pb.Object_UintptrArrayValue{&pb.Uintptrs{Values: pbSlice(obj).Interface().([]uint64)}}} - case reflect.Int8: - object = &pb.Object{Value: &pb.Object_Int8ArrayValue{&pb.Int8S{Values: pbSlice(obj).Interface().([]byte)}}} - case reflect.Int16: - // 16-bit slices are serialized as 32-bit slices. - // See object.proto for details. - s := pbSlice(obj).Interface().([]int16) - t := make([]int32, len(s)) - for i := range s { - t[i] = int32(s[i]) - } - object = &pb.Object{Value: &pb.Object_Int16ArrayValue{&pb.Int16S{Values: t}}} - case reflect.Int32: - object = &pb.Object{Value: &pb.Object_Int32ArrayValue{&pb.Int32S{Values: pbSlice(obj).Interface().([]int32)}}} - case reflect.Int64: - object = &pb.Object{Value: &pb.Object_Int64ArrayValue{&pb.Int64S{Values: pbSlice(obj).Interface().([]int64)}}} - case reflect.Bool: - object = &pb.Object{Value: &pb.Object_BoolArrayValue{&pb.Bools{Values: pbSlice(obj).Interface().([]bool)}}} - case reflect.Float32: - object = &pb.Object{Value: &pb.Object_Float32ArrayValue{&pb.Float32S{Values: pbSlice(obj).Interface().([]float32)}}} - case reflect.Float64: - object = &pb.Object{Value: &pb.Object_Float64ArrayValue{&pb.Float64S{Values: pbSlice(obj).Interface().([]float64)}}} - default: - object = &pb.Object{Value: &pb.Object_ArrayValue{es.encodeArray(obj)}} - } + es.encodeArray(obj, dest) case reflect.Slice: - if obj.IsNil() || obj.Cap() == 0 { - // Handled specially in decode; store as nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - // Serialize a slice as the array plus length and capacity. - object = &pb.Object{Value: &pb.Object_SliceValue{&pb.Slice{ - Capacity: uint32(obj.Cap()), - Length: uint32(obj.Len()), - RefValue: es.register(arrayFromSlice(obj)), - }}} + s := &wire.Slice{ + Capacity: wire.Uint(obj.Cap()), + Length: wire.Uint(obj.Len()), } - case reflect.String: - object = &pb.Object{Value: &pb.Object_StringValue{[]byte(obj.String())}} - case reflect.Ptr: + *dest = s + // Note that we do need to provide a wire.Slice type here as + // how is not encodeDefault. If this were the case, then it + // would have been caught by the IsZero check above and we + // would have just used wire.Nil{}. if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else { - es.push(true /* dereference */, "", nil) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} - es.pop() + return } + // Slices need pointer resolution. + es.resolve(arrayFromSlice(obj), &s.Ref) case reflect.Interface: - // We don't check for IsNil here, as we want to encode type - // information. The case of the empty interface (no type, no - // value) is handled by encodeInteface. - object = &pb.Object{Value: &pb.Object_InterfaceValue{es.encodeInterface(obj)}} + es.encodeInterface(obj, dest) case reflect.Struct: - object = &pb.Object{Value: &pb.Object_StructValue{es.encodeStruct(obj)}} + es.encodeStruct(obj, dest) case reflect.Map: - if obj.IsNil() { - // Handled specially in decode; store as a nil value. - object = &pb.Object{Value: &pb.Object_RefValue{0}} - } else if mapAsValue { - // Encode the map directly. - object = &pb.Object{Value: &pb.Object_MapValue{es.encodeMap(obj)}} - } else { - // Encode a reference to the map. - // - // Remove the map object count here to avoid double - // counting, as this object will be counted again when - // it gets processed later. We do not add a reference - // count as the reference is artificial. - es.stats.Remove(obj) - object = &pb.Object{Value: &pb.Object_RefValue{es.register(obj)}} + if how == encodeMapAsValue { + es.encodeMap(obj, dest) + return } + r := new(wire.Ref) + *dest = r + es.resolve(obj, r) default: - panic(fmt.Errorf("unknown primitive %#v", obj.Interface())) + Failf("unknown object %#v", obj.Interface()) + panic("unreachable") } - - es.stats.Done() - es.pop() - return } -// Serialize serializes the object state. -// -// This function may panic and should be run in safely(). -func (es *encodeState) Serialize(obj reflect.Value) { - es.register(obj.Addr()) - - // Pop off the list until we're done. - for es.pending.Len() > 0 { - e := es.pending.Front() - - // Extract the queued object. - qo := e.Value.(queuedObject) - es.stats.Start(qo.obj) +// Save serializes the object graph rooted at obj. +func (es *encodeState) Save(obj reflect.Value) { + es.stats.init() + defer es.stats.fini(func(id typeID) string { + return es.pendingTypes[id-1].Name + }) + + // Resolve the first object, which should queue a pile of additional + // objects on the pending list. All queued objects should be fully + // resolved, and we should be able to serialize after this call. + var root wire.Ref + es.resolve(obj.Addr(), &root) + + // Encode the graph. + var oes *objectEncodeState + if err := safely(func() { + for oes = es.deferred.Front(); oes != nil; oes = es.deferred.Front() { + // Remove and encode the object. Note that as a result + // of this encoding, the object may be enqueued on the + // deferred list yet again. That's expected, and why it + // is removed first. + es.deferred.Remove(oes) + es.encodeObject(oes.obj, oes.how, &oes.encoded) + } + }); err != nil { + // Include the object in the error message. + Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) + } - es.pending.Remove(e) + // Check that items are pending. + if es.pending.Front() == nil { + Failf("pending is empty?") + } - es.from = &qo.path - o := es.encodeObject(qo.obj, true, "", nil) + // Write the header with the number of objects. Note that there is no + // way that es.lastID could conflict with objectID, which would + // indicate that an impossibly large encoding. + if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + Failf("error writing header: %w", err) + } - // Emit to our output stream. - if err := es.writeObject(qo.id, o); err != nil { - panic(err) + // Serialize all pending types and pending objects. Note that we don't + // bother removing from this list as we walk it because that just + // wastes time. It will not change after this point. + var id objectID + if err := safely(func() { + for _, wt := range es.pendingTypes { + // Encode the type. + wire.Save(es.w, &wt) } + for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { + id++ // First object is 1. + if oes.id != id { + Failf("expected id %d, got %d", id, oes.id) + } - // Mark as done. - es.done.PushBack(e) - es.stats.Done() + // Marshall the object. + wire.Save(es.w, oes.encoded) + } + }); err != nil { + // Include the object and the error. + Failf("error serializing object %#v: %w", oes.encoded, err) } - // Write a zero-length terminal at the end; this is a sanity check - // applied at decode time as well (see decode.go). - if err := WriteHeader(es.w, 0, false); err != nil { - panic(err) + // Check what we wrote. + if id != es.lastID { + Failf("expected %d objects, wrote %d", es.lastID, id) } } +// objectFlag indicates that the length is a # of objects, rather than a raw +// byte length. When this is set on a length header in the stream, it may be +// decoded appropriately. +const objectFlag uint64 = 1 << 63 + // WriteHeader writes a header. // // Each object written to the statefile should be prefixed with a header. In // order to generate statefiles that play nicely with debugging tools, raw // writes should be prefixed with a header with object set to false and the // appropriate length. This will allow tools to skip these regions. -func WriteHeader(w io.Writer, length uint64, object bool) error { - // The lowest-order bit encodes whether this is a valid object. This is - // a purely internal convention, but allows the object flag to be - // returned from ReadHeader. - length = length << 1 +func WriteHeader(w wire.Writer, length uint64, object bool) error { + // Sanity check the length. + if length&objectFlag != 0 { + Failf("impossibly huge length: %d", length) + } if object { - length |= 0x1 + length |= objectFlag } // Write a header. - var hdr [32]byte - encodedLen := binary.PutUvarint(hdr[:], length) - for done := 0; done < encodedLen; { - n, err := w.Write(hdr[done:encodedLen]) - done += n - if n == 0 && err != nil { - return err - } - } - - return nil + return safely(func() { + wire.SaveUint(w, length) + }) } -// writeObject writes an object to the stream. -func (es *encodeState) writeObject(id uint64, obj *pb.Object) error { - // Marshal the proto. - buf, err := proto.Marshal(obj) - if err != nil { - return err - } +// pendingMapper is for the pending list. +type pendingMapper struct{} - // Write the object header. - if err := WriteHeader(es.w, uint64(len(buf)), true); err != nil { - return err - } +func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // Write the object. - for done := 0; done < len(buf); { - n, err := es.w.Write(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } +// deferredMapper is for the deferred list. +type deferredMapper struct{} - return nil -} +func (deferredMapper) linkerFor(oes *objectEncodeState) *deferredEntry { return &oes.deferredEntry } // addrSetFunctions is used by addrSet. type addrSetFunctions struct{} @@ -458,13 +818,24 @@ func (addrSetFunctions) MaxKey() uintptr { return ^uintptr(0) } -func (addrSetFunctions) ClearValue(val *reflect.Value) { +func (addrSetFunctions) ClearValue(val **objectEncodeState) { + *val = nil } -func (addrSetFunctions) Merge(_ addrRange, val1 reflect.Value, _ addrRange, val2 reflect.Value) (reflect.Value, bool) { - return val1, val1 == val2 +func (addrSetFunctions) Merge(r1 addrRange, val1 *objectEncodeState, r2 addrRange, val2 *objectEncodeState) (*objectEncodeState, bool) { + if val1.obj == val2.obj { + // This, should never happen. It would indicate that the same + // object exists in two non-contiguous address ranges. Note + // that this assertion can only be triggered if the race + // detector is enabled. + Failf("unexpected merge in addrSet @ %v and %v: %#v and %#v", r1, r2, val1.obj, val2.obj) + } + // Reject the merge. + return val1, false } -func (addrSetFunctions) Split(_ addrRange, val reflect.Value, _ uintptr) (reflect.Value, reflect.Value) { - return val, val +func (addrSetFunctions) Split(r addrRange, val *objectEncodeState, _ uintptr) (*objectEncodeState, *objectEncodeState) { + // A split should never happen: we don't remove ranges. + Failf("unexpected split in addrSet @ %v: %#v", r, val.obj) + panic("unreachable") } diff --git a/pkg/state/encode_unsafe.go b/pkg/state/encode_unsafe.go index 457e6dbb7..e0dad83b4 100644 --- a/pkg/state/encode_unsafe.go +++ b/pkg/state/encode_unsafe.go @@ -31,51 +31,3 @@ func arrayFromSlice(obj reflect.Value) reflect.Value { reflect.ArrayOf(obj.Cap(), obj.Type().Elem()), unsafe.Pointer(obj.Pointer())) } - -// pbSlice returns a protobuf-supported slice of the array and erase the -// original element type (which could be a defined type or non-supported type). -func pbSlice(obj reflect.Value) reflect.Value { - var typ reflect.Type - switch obj.Type().Elem().Kind() { - case reflect.Uint8: - typ = reflect.TypeOf(byte(0)) - case reflect.Uint16: - typ = reflect.TypeOf(uint16(0)) - case reflect.Uint32: - typ = reflect.TypeOf(uint32(0)) - case reflect.Uint64: - typ = reflect.TypeOf(uint64(0)) - case reflect.Uintptr: - typ = reflect.TypeOf(uint64(0)) - case reflect.Int8: - typ = reflect.TypeOf(byte(0)) - case reflect.Int16: - typ = reflect.TypeOf(int16(0)) - case reflect.Int32: - typ = reflect.TypeOf(int32(0)) - case reflect.Int64: - typ = reflect.TypeOf(int64(0)) - case reflect.Bool: - typ = reflect.TypeOf(bool(false)) - case reflect.Float32: - typ = reflect.TypeOf(float32(0)) - case reflect.Float64: - typ = reflect.TypeOf(float64(0)) - default: - panic("slice element is not of basic value type") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), typ), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem().Slice(0, obj.Len()) -} - -func castSlice(obj reflect.Value, elemTyp reflect.Type) reflect.Value { - if obj.Type().Elem().Size() != elemTyp.Size() { - panic("cannot cast slice into other element type of different size") - } - return reflect.NewAt( - reflect.ArrayOf(obj.Len(), elemTyp), - unsafe.Pointer(obj.Slice(0, obj.Len()).Pointer()), - ).Elem() -} diff --git a/pkg/state/map.go b/pkg/state/map.go deleted file mode 100644 index 4f3ebb0da..000000000 --- a/pkg/state/map.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "context" - "fmt" - "reflect" - "sort" - "sync" - - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// entry is a single map entry. -type entry struct { - name string - object *pb.Object -} - -// internalMap is the internal Map state. -// -// These are recycled via a pool to avoid churn. -type internalMap struct { - // es is encodeState. - es *encodeState - - // ds is decodeState. - ds *decodeState - - // os is current object being decoded. - // - // This will always be nil during encode. - os *objectState - - // data stores the encoded values. - data []entry -} - -var internalMapPool = sync.Pool{ - New: func() interface{} { - return new(internalMap) - }, -} - -// newInternalMap returns a cached map. -func newInternalMap(es *encodeState, ds *decodeState, os *objectState) *internalMap { - m := internalMapPool.Get().(*internalMap) - m.es = es - m.ds = ds - m.os = os - if m.data != nil { - m.data = m.data[:0] - } - return m -} - -// Map is a generic state container. -// -// This is the object passed to Save and Load in order to store their state. -// -// Detailed documentation is available in individual methods. -type Map struct { - *internalMap -} - -// Save adds the given object to the map. -// -// You should pass always pointers to the object you are saving. For example: -// -// type X struct { -// A int -// B *int -// } -// -// func (x *X) Save(m Map) { -// m.Save("A", &x.A) -// m.Save("B", &x.B) -// } -// -// func (x *X) Load(m Map) { -// m.Load("A", &x.A) -// m.Load("B", &x.B) -// } -func (m Map) Save(name string, objPtr interface{}) { - m.save(name, reflect.ValueOf(objPtr).Elem(), ".%s") -} - -// SaveValue adds the given object value to the map. -// -// This should be used for values where pointers are not available, or casts -// are required during Save/Load. -// -// For example, if we want to cast external package type P.Foo to int64: -// -// type X struct { -// A P.Foo -// } -// -// func (x *X) Save(m Map) { -// m.SaveValue("A", int64(x.A)) -// } -// -// func (x *X) Load(m Map) { -// m.LoadValue("A", new(int64), func(x interface{}) { -// x.A = P.Foo(x.(int64)) -// }) -// } -func (m Map) SaveValue(name string, obj interface{}) { - m.save(name, reflect.ValueOf(obj), ".(value %s)") -} - -// save is helper for the above. It takes the name of value to save the field -// to, the field object (obj), and a format string that specifies how the -// field's saving logic is dispatched from the struct (normal, value, etc.). The -// format string should expect one string parameter, which is the name of the -// field. -func (m Map) save(name string, obj reflect.Value, format string) { - if m.es == nil { - // Not currently encoding. - m.Failf("no encode state for %q", name) - } - - // Attempt the encode. - // - // These are sorted at the end, after all objects are added and will be - // sorted and checked for duplicates (see encodeStruct). - m.data = append(m.data, entry{ - name: name, - object: m.es.encodeObject(obj, false, format, name), - }) -} - -// Load loads the given object from the map. -// -// See Save for an example. -func (m Map) Load(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), false, nil, ".%s") -} - -// LoadWait loads the given objects from the map, and marks it as requiring all -// AfterLoad executions to complete prior to running this object's AfterLoad. -// -// See Save for an example. -func (m Map) LoadWait(name string, objPtr interface{}) { - m.load(name, reflect.ValueOf(objPtr), true, nil, ".(wait %s)") -} - -// LoadValue loads the given object value from the map. -// -// See SaveValue for an example. -func (m Map) LoadValue(name string, objPtr interface{}, fn func(interface{})) { - o := reflect.ValueOf(objPtr) - m.load(name, o, true, func() { fn(o.Elem().Interface()) }, ".(value %s)") -} - -// load is helper for the above. It takes the name of value to load the field -// from, the target field pointer (objPtr), whether load completion of the -// struct depends on the field's load completion (wait), the load completion -// logic (fn), and a format string that specifies how the field's loading logic -// is dispatched from the struct (normal, wait, value, etc.). The format string -// should expect one string parameter, which is the name of the field. -func (m Map) load(name string, objPtr reflect.Value, wait bool, fn func(), format string) { - if m.ds == nil { - // Not currently decoding. - m.Failf("no decode state for %q", name) - } - - // Find the object. - // - // These are sorted up front (and should appear in the state file - // sorted as well), so we can do a binary search here to ensure that - // large structs don't behave badly. - i := sort.Search(len(m.data), func(i int) bool { - return m.data[i].name >= name - }) - if i >= len(m.data) || m.data[i].name != name { - // There is no data for this name? - m.Failf("no data found for %q", name) - } - - // Perform the decode. - m.ds.decodeObject(m.os, objPtr.Elem(), m.data[i].object, format, name) - if wait { - // Mark this individual object a blocker. - m.ds.waitObject(m.os, m.data[i].object, fn) - } -} - -// Failf fails the save or restore with the provided message. Processing will -// stop after calling Failf, as the state package uses a panic & recover -// mechanism for state errors. You should defer any cleanup required. -func (m Map) Failf(format string, args ...interface{}) { - panic(fmt.Errorf(format, args...)) -} - -// AfterLoad schedules a function execution when all objects have been allocated -// and their automated loading and customized load logic have been executed. fn -// will not be executed until all of current object's dependencies' AfterLoad() -// logic, if exist, have been executed. -func (m Map) AfterLoad(fn func()) { - if m.ds == nil { - // Not currently decoding. - m.Failf("not decoding") - } - - // Queue the local callback; this will execute when all of the above - // data dependencies have been cleared. - m.os.callbacks = append(m.os.callbacks, fn) -} - -// Context returns the current context object. -func (m Map) Context() context.Context { - if m.es != nil { - return m.es.ctx - } else if m.ds != nil { - return m.ds.ctx - } - return context.Background() // No context. -} diff --git a/pkg/state/object.proto b/pkg/state/object.proto deleted file mode 100644 index 5ebcfb151..000000000 --- a/pkg/state/object.proto +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -package gvisor.state.statefile; - -// Slice is a slice value. -message Slice { - uint32 length = 1; - uint32 capacity = 2; - uint64 ref_value = 3; -} - -// Array is an array value. -message Array { - repeated Object contents = 1; -} - -// Map is a map value. -message Map { - repeated Object keys = 1; - repeated Object values = 2; -} - -// Interface is an interface value. -message Interface { - string type = 1; - Object value = 2; -} - -// Struct is a basic composite value. -message Struct { - repeated Field fields = 1; -} - -// Field encodes a single field. -message Field { - string name = 1; - Object value = 2; -} - -// Uint16s encodes an uint16 array. To be used inside oneof structure. -message Uint16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated uint32 values = 1; -} - -// Uint32s encodes an uint32 array. To be used inside oneof structure. -message Uint32s { - repeated fixed32 values = 1; -} - -// Uint64s encodes an uint64 array. To be used inside oneof structure. -message Uint64s { - repeated fixed64 values = 1; -} - -// Uintptrs encodes an uintptr array. To be used inside oneof structure. -message Uintptrs { - repeated fixed64 values = 1; -} - -// Int8s encodes an int8 array. To be used inside oneof structure. -message Int8s { - bytes values = 1; -} - -// Int16s encodes an int16 array. To be used inside oneof structure. -message Int16s { - // There is no 16-bit type in protobuf so we use variable length 32-bit here. - repeated int32 values = 1; -} - -// Int32s encodes an int32 array. To be used inside oneof structure. -message Int32s { - repeated sfixed32 values = 1; -} - -// Int64s encodes an int64 array. To be used inside oneof structure. -message Int64s { - repeated sfixed64 values = 1; -} - -// Bools encodes a boolean array. To be used inside oneof structure. -message Bools { - repeated bool values = 1; -} - -// Float64s encodes a float64 array. To be used inside oneof structure. -message Float64s { - repeated double values = 1; -} - -// Float32s encodes a float32 array. To be used inside oneof structure. -message Float32s { - repeated float values = 1; -} - -// Object are primitive encodings. -// -// Note that ref_value references an Object.id, below. -message Object { - oneof value { - bool bool_value = 1; - bytes string_value = 2; - int64 int64_value = 3; - uint64 uint64_value = 4; - double double_value = 5; - uint64 ref_value = 6; - Slice slice_value = 7; - Array array_value = 8; - Interface interface_value = 9; - Struct struct_value = 10; - Map map_value = 11; - bytes byte_array_value = 12; - Uint16s uint16_array_value = 13; - Uint32s uint32_array_value = 14; - Uint64s uint64_array_value = 15; - Uintptrs uintptr_array_value = 16; - Int8s int8_array_value = 17; - Int16s int16_array_value = 18; - Int32s int32_array_value = 19; - Int64s int64_array_value = 20; - Bools bool_array_value = 21; - Float64s float64_array_value = 22; - Float32s float32_array_value = 23; - } -} diff --git a/pkg/state/pretty/BUILD b/pkg/state/pretty/BUILD new file mode 100644 index 000000000..d053802f7 --- /dev/null +++ b/pkg/state/pretty/BUILD @@ -0,0 +1,13 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pretty", + srcs = ["pretty.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go new file mode 100644 index 000000000..cf37aaa49 --- /dev/null +++ b/pkg/state/pretty/pretty.go @@ -0,0 +1,273 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pretty is a pretty-printer for state streams. +package pretty + +import ( + "fmt" + "io" + "io/ioutil" + "reflect" + "strings" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +func formatRef(x *wire.Ref, graph uint64, html bool) string { + baseRef := fmt.Sprintf("g%dr%d", graph, x.Root) + fullRef := baseRef + if len(x.Dots) > 0 { + // See wire.Ref; Type valid if Dots non-zero. + typ, _ := formatType(x.Type, graph, html) + var buf strings.Builder + buf.WriteString("(*") + buf.WriteString(typ) + buf.WriteString(")(") + buf.WriteString(baseRef) + for _, component := range x.Dots { + switch v := component.(type) { + case *wire.FieldName: + buf.WriteString(".") + buf.WriteString(string(*v)) + case wire.Index: + buf.WriteString(fmt.Sprintf("[%d]", v)) + default: + panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) + } + } + buf.WriteString(")") + fullRef = buf.String() + } + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef) + } + return fullRef +} + +func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { + switch x := t.(type) { + case wire.TypeID: + base := fmt.Sprintf("g%dt%d", graph, x) + if html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true + } + return fmt.Sprintf("%s", base), true + case wire.TypeSpecNil: + return "", false // Only nil type. + case *wire.TypeSpecPointer: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("(*%s)", element), true + case *wire.TypeSpecArray: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("[%d](%s)", x.Count, element), true + case *wire.TypeSpecSlice: + element, _ := formatType(x.Type, graph, html) + return fmt.Sprintf("([]%s)", element), true + case *wire.TypeSpecMap: + key, _ := formatType(x.Key, graph, html) + value, _ := formatType(x.Value, graph, html) + return fmt.Sprintf("(map[%s]%s)", key, value), true + default: + panic(fmt.Sprintf("unreachable: unknown type %T", t)) + } +} + +// format formats a single object, for pretty-printing. It also returns whether +// the value is a non-zero value. +func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) { + switch x := encoded.(type) { + case wire.Nil: + return "nil", false + case *wire.String: + return fmt.Sprintf("%q", *x), *x != "" + case *wire.Complex64: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Complex128: + return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 + case *wire.Ref: + return formatRef(x, graph, html), x.Root != 0 + case *wire.Type: + tabs := "\n" + strings.Repeat("\t", depth) + items := make([]string, 0, len(x.Fields)+2) + items = append(items, fmt.Sprintf("type %s {", x.Name)) + for i := 0; i < len(x.Fields); i++ { + items = append(items, fmt.Sprintf("\t%d: %s,", i, x.Fields[i])) + } + items = append(items, "}") + return strings.Join(items, tabs), true // No zero value. + case *wire.Slice: + return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0 + case *wire.Array: + if len(x.Contents) == 0 { + return "[]", false + } + items := make([]string, 0, len(x.Contents)+2) + zeros := make([]string, 0) // used to eliminate zero entries. + items = append(items, "[") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.Contents); i++ { + item, ok := format(graph, depth+1, x.Contents[i], html) + if !ok { + zeros = append(zeros, fmt.Sprintf("\t%s,", item)) + continue + } + if len(zeros) > 0 { + items = append(items, zeros...) + zeros = nil + } + items = append(items, fmt.Sprintf("\t%s,", item)) + } + if len(zeros) > 0 { + items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros))) + } + items = append(items, "]") + return strings.Join(items, tabs), len(zeros) < len(x.Contents) + case *wire.Struct: + typ, _ := formatType(x.TypeID, graph, html) + if x.Fields() == 0 { + return fmt.Sprintf("struct[%s]{}", typ), false + } + items := make([]string, 0, 2) + items = append(items, fmt.Sprintf("struct[%s]{", typ)) + tabs := "\n" + strings.Repeat("\t", depth) + allZero := true + for i := 0; i < x.Fields(); i++ { + element, ok := format(graph, depth+1, *x.Field(i), html) + allZero = allZero && !ok + items = append(items, fmt.Sprintf("\t%d: %s,", i, element)) + i++ + } + items = append(items, "}") + return strings.Join(items, tabs), !allZero + case *wire.Map: + if len(x.Keys) == 0 { + return "map{}", false + } + items := make([]string, 0, len(x.Keys)+2) + items = append(items, "map{") + tabs := "\n" + strings.Repeat("\t", depth) + for i := 0; i < len(x.Keys); i++ { + key, _ := format(graph, depth+1, x.Keys[i], html) + value, _ := format(graph, depth+1, x.Values[i], html) + items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) + } + items = append(items, "}") + return strings.Join(items, tabs), true + case *wire.Interface: + typ, typOk := formatType(x.Type, graph, html) + element, elementOk := format(graph, depth+1, x.Value, html) + return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk + default: + // Must be a primitive; use reflection. + return fmt.Sprintf("%v", encoded), true + } +} + +// printStream is the basic print implementation. +func printStream(w io.Writer, r wire.Reader, html bool) (err error) { + // current graph ID. + var graph uint64 + + if html { + fmt.Fprintf(w, "<pre>") + defer fmt.Fprintf(w, "</pre>") + } + + defer func() { + if r := recover(); r != nil { + if rErr, ok := r.(error); ok { + err = rErr // Override return. + return + } + panic(r) // Propagate. + } + }() + + for { + // Find the first object to begin generation. + length, object, err := state.ReadHeader(r) + if err == io.EOF { + // Nothing else to do. + break + } else if err != nil { + return err + } + if !object { + graph++ // Increment the graph. + if length > 0 { + fmt.Fprintf(w, "(%d bytes non-object data)\n", length) + io.Copy(ioutil.Discard, &io.LimitedReader{ + R: r, + N: int64(length), + }) + } + continue + } + + // Read & unmarshal the object. + // + // Note that this loop must match the general structure of the + // loop in decode.go. But we don't register type information, + // etc. and just print the raw structures. + var ( + oid uint64 = 1 + tid uint64 = 1 + ) + for oid <= length { + // Unmarshal the object. + encoded := wire.Load(r) + + // Is this a type? + if _, ok := encoded.(*wire.Type); ok { + str, _ := format(graph, 0, encoded, html) + tag := fmt.Sprintf("g%dt%d", graph, tid) + if html { + // See below. + tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) + } + if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { + return err + } + tid++ + continue + } + + // Format the node. + str, _ := format(graph, 0, encoded, html) + tag := fmt.Sprintf("g%dr%d", graph, oid) + if html { + // Create a little tag with an anchor next to it for linking. + tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) + } + if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { + return err + } + oid++ + } + } + + return nil +} + +// PrintText reads the stream from r and prints text to w. +func PrintText(w io.Writer, r wire.Reader) error { + return printStream(w, r, false /* html */) +} + +// PrintHTML reads the stream from r and prints html to w. +func PrintHTML(w io.Writer, r wire.Reader) error { + return printStream(w, r, true /* html */) +} diff --git a/pkg/state/printer.go b/pkg/state/printer.go deleted file mode 100644 index 3ce18242f..000000000 --- a/pkg/state/printer.go +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "fmt" - "io" - "io/ioutil" - "reflect" - "strings" - - "github.com/golang/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" -) - -// format formats a single object, for pretty-printing. It also returns whether -// the value is a non-zero value. -func format(graph uint64, depth int, object *pb.Object, html bool) (string, bool) { - switch x := object.GetValue().(type) { - case *pb.Object_BoolValue: - return fmt.Sprintf("%t", x.BoolValue), x.BoolValue != false - case *pb.Object_StringValue: - return fmt.Sprintf("\"%s\"", string(x.StringValue)), len(x.StringValue) != 0 - case *pb.Object_Int64Value: - return fmt.Sprintf("%d", x.Int64Value), x.Int64Value != 0 - case *pb.Object_Uint64Value: - return fmt.Sprintf("%du", x.Uint64Value), x.Uint64Value != 0 - case *pb.Object_DoubleValue: - return fmt.Sprintf("%f", x.DoubleValue), x.DoubleValue != 0.0 - case *pb.Object_RefValue: - if x.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return ref, true - case *pb.Object_SliceValue: - if x.SliceValue.RefValue == 0 { - return "nil", false - } - ref := fmt.Sprintf("g%dr%d", graph, x.SliceValue.RefValue) - if html { - ref = fmt.Sprintf("<a href=#%s>%s</a>", ref, ref) - } - return fmt.Sprintf("%s[:%d:%d]", ref, x.SliceValue.Length, x.SliceValue.Capacity), true - case *pb.Object_ArrayValue: - if len(x.ArrayValue.Contents) == 0 { - return "[]", false - } - items := make([]string, 0, len(x.ArrayValue.Contents)+2) - zeros := make([]string, 0) // used to eliminate zero entries. - items = append(items, "[") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.ArrayValue.Contents); i++ { - item, ok := format(graph, depth+1, x.ArrayValue.Contents[i], html) - if ok { - if len(zeros) > 0 { - items = append(items, zeros...) - zeros = nil - } - items = append(items, fmt.Sprintf("\t%s,", item)) - } else { - zeros = append(zeros, fmt.Sprintf("\t%s,", item)) - } - } - if len(zeros) > 0 { - items = append(items, fmt.Sprintf("\t... (%d zeros),", len(zeros))) - } - items = append(items, "]") - return strings.Join(items, tabs), len(zeros) < len(x.ArrayValue.Contents) - case *pb.Object_StructValue: - if len(x.StructValue.Fields) == 0 { - return "struct{}", false - } - items := make([]string, 0, len(x.StructValue.Fields)+2) - items = append(items, "struct{") - tabs := "\n" + strings.Repeat("\t", depth) - allZero := true - for _, field := range x.StructValue.Fields { - element, ok := format(graph, depth+1, field.Value, html) - allZero = allZero && !ok - items = append(items, fmt.Sprintf("\t%s: %s,", field.Name, element)) - } - items = append(items, "}") - return strings.Join(items, tabs), !allZero - case *pb.Object_MapValue: - if len(x.MapValue.Keys) == 0 { - return "map{}", false - } - items := make([]string, 0, len(x.MapValue.Keys)+2) - items = append(items, "map{") - tabs := "\n" + strings.Repeat("\t", depth) - for i := 0; i < len(x.MapValue.Keys); i++ { - key, _ := format(graph, depth+1, x.MapValue.Keys[i], html) - value, _ := format(graph, depth+1, x.MapValue.Values[i], html) - items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) - } - items = append(items, "}") - return strings.Join(items, tabs), true - case *pb.Object_InterfaceValue: - if x.InterfaceValue.Type == "" { - return "interface(nil){}", false - } - element, _ := format(graph, depth+1, x.InterfaceValue.Value, html) - return fmt.Sprintf("interface(\"%s\"){%s}", x.InterfaceValue.Type, element), true - case *pb.Object_ByteArrayValue: - return printArray(reflect.ValueOf(x.ByteArrayValue)) - case *pb.Object_Uint16ArrayValue: - return printArray(reflect.ValueOf(x.Uint16ArrayValue.Values)) - case *pb.Object_Uint32ArrayValue: - return printArray(reflect.ValueOf(x.Uint32ArrayValue.Values)) - case *pb.Object_Uint64ArrayValue: - return printArray(reflect.ValueOf(x.Uint64ArrayValue.Values)) - case *pb.Object_UintptrArrayValue: - return printArray(castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0)))) - case *pb.Object_Int8ArrayValue: - return printArray(castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0)))) - case *pb.Object_Int16ArrayValue: - return printArray(reflect.ValueOf(x.Int16ArrayValue.Values)) - case *pb.Object_Int32ArrayValue: - return printArray(reflect.ValueOf(x.Int32ArrayValue.Values)) - case *pb.Object_Int64ArrayValue: - return printArray(reflect.ValueOf(x.Int64ArrayValue.Values)) - case *pb.Object_BoolArrayValue: - return printArray(reflect.ValueOf(x.BoolArrayValue.Values)) - case *pb.Object_Float64ArrayValue: - return printArray(reflect.ValueOf(x.Float64ArrayValue.Values)) - case *pb.Object_Float32ArrayValue: - return printArray(reflect.ValueOf(x.Float32ArrayValue.Values)) - } - - // Should not happen, but tolerate. - return fmt.Sprintf("(unknown proto type: %T)", object.GetValue()), true -} - -// PrettyPrint reads the state stream from r, and pretty prints to w. -func PrettyPrint(w io.Writer, r io.Reader, html bool) error { - var ( - // current graph ID. - graph uint64 - - // current object ID. - id uint64 - ) - - if html { - fmt.Fprintf(w, "<pre>") - defer fmt.Fprintf(w, "</pre>") - } - - for { - // Find the first object to begin generation. - length, object, err := ReadHeader(r) - if err == io.EOF { - // Nothing else to do. - break - } else if err != nil { - return err - } - if !object { - // Increment the graph number & reset the ID. - graph++ - id = 0 - if length > 0 { - fmt.Fprintf(w, "(%d bytes non-object data)\n", length) - io.Copy(ioutil.Discard, &io.LimitedReader{ - R: r, - N: int64(length), - }) - } - continue - } - - // Read & unmarshal the object. - buf := make([]byte, length) - for done := 0; done < len(buf); { - n, err := r.Read(buf[done:]) - done += n - if n == 0 && err != nil { - return err - } - } - obj := new(pb.Object) - if err := proto.Unmarshal(buf, obj); err != nil { - return err - } - - id++ // First object must be one. - str, _ := format(graph, 0, obj, html) - tag := fmt.Sprintf("g%dr%d", graph, id) - if html { - tag = fmt.Sprintf("<a name=%s>%s</a>", tag, tag) - } - if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { - return err - } - } - - return nil -} - -func printArray(s reflect.Value) (string, bool) { - zero := reflect.Zero(s.Type().Elem()).Interface() - z := "0" - switch s.Type().Elem().Kind() { - case reflect.Bool: - z = "false" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - case reflect.Float32, reflect.Float64: - default: - return fmt.Sprintf("unexpected non-primitive type array: %#v", s.Interface()), true - } - - zeros := 0 - items := make([]string, 0, s.Len()) - for i := 0; i <= s.Len(); i++ { - if i < s.Len() && reflect.DeepEqual(s.Index(i).Interface(), zero) { - zeros++ - continue - } - if zeros > 0 { - if zeros <= 4 { - for ; zeros > 0; zeros-- { - items = append(items, z) - } - } else { - items = append(items, fmt.Sprintf("(%d %ss)", zeros, z)) - zeros = 0 - } - } - if i < s.Len() { - items = append(items, fmt.Sprintf("%v", s.Index(i).Interface())) - } - } - return "[" + strings.Join(items, ",") + "]", zeros < s.Len() -} diff --git a/pkg/state/state.go b/pkg/state/state.go index 03ae2dbb0..acb629969 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -31,210 +31,226 @@ // Uint64 default // Float32 default // Float64 default -// Complex64 custom -// Complex128 custom +// Complex64 default +// Complex128 default // Array default // Chan custom // Func custom -// Interface custom -// Map default (*) +// Interface default +// Map default // Ptr default // Slice default // String default -// Struct custom +// Struct custom (*) Unless zero-sized. // UnsafePointer custom // -// (*) Maps are treated as value types by this package, even if they are -// pointers internally. If you want to save two independent references -// to the same map value, you must explicitly use a pointer to a map. +// See README.md for an overview of how encoding and decoding works. package state import ( "context" "fmt" - "io" "reflect" "runtime" - pb "gvisor.dev/gvisor/pkg/state/object_go_proto" + "gvisor.dev/gvisor/pkg/state/wire" ) +// objectID is a unique identifier assigned to each object to be serialized. +// Each instance of an object is considered separately, i.e. if there are two +// objects of the same type in the object graph being serialized, they'll be +// assigned unique objectIDs. +type objectID uint32 + +// typeID is the identifier for a type. Types are serialized and tracked +// alongside objects in order to avoid the overhead of encoding field names in +// all objects. +type typeID uint32 + // ErrState is returned when an error is encountered during encode/decode. type ErrState struct { // err is the underlying error. err error - // path is the visit path from root to the current object. - path string - // trace is the stack trace. trace string } // Error returns a sensible description of the state error. func (e *ErrState) Error() string { - return fmt.Sprintf("%v:\nstate path: %s\n%s", e.err, e.path, e.trace) + return fmt.Sprintf("%v:\n%s", e.err, e.trace) } -// UnwrapErrState returns the underlying error in ErrState. -// -// If err is not *ErrState, err is returned directly. -func UnwrapErrState(err error) error { - if e, ok := err.(*ErrState); ok { - return e.err - } - return err +// Unwrap implements standard unwrapping. +func (e *ErrState) Unwrap() error { + return e.err } // Save saves the given object state. -func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. - es := &encodeState{ - ctx: ctx, - idsByObject: make(map[uintptr]uint64), - w: w, - stats: stats, + es := encodeState{ + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), } // Perform the encoding. - return es.safely(func() { - es.Serialize(reflect.ValueOf(rootPtr).Elem()) + err := safely(func() { + es.Save(reflect.ValueOf(rootPtr).Elem()) }) + return es.stats, err } // Load loads a checkpoint. -func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r wire.Reader, rootPtr interface{}) (Stats, error) { // Create the decoding state. - ds := &decodeState{ - ctx: ctx, - objectsByID: make(map[uint64]*objectState), - deferred: make(map[uint64]*pb.Object), - r: r, - stats: stats, + ds := decodeState{ + ctx: ctx, + r: r, + types: makeTypeDecodeDatabase(), + deferred: make(map[objectID]wire.Object), } // Attempt our decode. - return ds.safely(func() { - ds.Deserialize(reflect.ValueOf(rootPtr).Elem()) + err := safely(func() { + ds.Load(reflect.ValueOf(rootPtr).Elem()) }) + return ds.stats, err } -// Fns are the state dispatch functions. -type Fns struct { - // Save is a function like Save(concreteType, Map). - Save interface{} - - // Load is a function like Load(concreteType, Map). - Load interface{} +// Sink is used for Type.StateSave. +type Sink struct { + internal objectEncoder } -// Save executes the save function. -func (fns *Fns) invokeSave(obj reflect.Value, m Map) { - reflect.ValueOf(fns.Save).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +// Save adds the given object to the map. +// +// You should pass always pointers to the object you are saving. For example: +// +// type X struct { +// A int +// B *int +// } +// +// func (x *X) StateTypeInfo(m Sink) state.TypeInfo { +// return state.TypeInfo{ +// Name: "pkg.X", +// Fields: []string{ +// "A", +// "B", +// }, +// } +// } +// +// func (x *X) StateSave(m Sink) { +// m.Save(0, &x.A) // Field is A. +// m.Save(1, &x.B) // Field is B. +// } +// +// func (x *X) StateLoad(m Source) { +// m.Load(0, &x.A) // Field is A. +// m.Load(1, &x.B) // Field is B. +// } +func (s Sink) Save(slot int, objPtr interface{}) { + s.internal.save(slot, reflect.ValueOf(objPtr).Elem()) } -// Load executes the load function. -func (fns *Fns) invokeLoad(obj reflect.Value, m Map) { - reflect.ValueOf(fns.Load).Call([]reflect.Value{obj, reflect.ValueOf(m)}) +// SaveValue adds the given object value to the map. +// +// This should be used for values where pointers are not available, or casts +// are required during Save/Load. +// +// For example, if we want to cast external package type P.Foo to int64: +// +// func (x *X) StateSave(m Sink) { +// m.SaveValue(0, "A", int64(x.A)) +// } +// +// func (x *X) StateLoad(m Source) { +// m.LoadValue(0, new(int64), func(x interface{}) { +// x.A = P.Foo(x.(int64)) +// }) +// } +func (s Sink) SaveValue(slot int, obj interface{}) { + s.internal.save(slot, reflect.ValueOf(obj)) } -// validateStateFn ensures types are correct. -func validateStateFn(fn interface{}, typ reflect.Type) bool { - fnTyp := reflect.TypeOf(fn) - if fnTyp.Kind() != reflect.Func { - return false - } - if fnTyp.NumIn() != 2 { - return false - } - if fnTyp.NumOut() != 0 { - return false - } - if fnTyp.In(0) != typ { - return false - } - if fnTyp.In(1) != reflect.TypeOf(Map{}) { - return false - } - return true +// Context returns the context object provided at save time. +func (s Sink) Context() context.Context { + return s.internal.es.ctx } -// Validate validates all state functions. -func (fns *Fns) Validate(typ reflect.Type) bool { - return validateStateFn(fns.Save, typ) && validateStateFn(fns.Load, typ) +// Type is an interface that must be implemented by Struct objects. This allows +// these objects to be serialized while minimizing runtime reflection required. +// +// All these methods can be automatically generated by the go_statify tool. +type Type interface { + // StateTypeName returns the type's name. + // + // This is used for matching type information during encoding and + // decoding, as well as dynamic interface dispatch. This should be + // globally unique. + StateTypeName() string + + // StateFields returns information about the type. + // + // Fields is the set of fields for the object. Calls to Sink.Save and + // Source.Load must be made in-order with respect to these fields. + // + // This will be called at most once per serialization. + StateFields() []string } -type typeDatabase struct { - // nameToType is a forward lookup table. - nameToType map[string]reflect.Type - - // typeToName is the reverse lookup table. - typeToName map[reflect.Type]string +// SaverLoader must be implemented by struct types. +type SaverLoader interface { + // StateSave saves the state of the object to the given Map. + StateSave(Sink) - // typeToFns is the function lookup table. - typeToFns map[reflect.Type]Fns + // StateLoad loads the state of the object. + StateLoad(Source) } -// registeredTypes is a database used for SaveInterface and LoadInterface. -var registeredTypes = typeDatabase{ - nameToType: make(map[string]reflect.Type), - typeToName: make(map[reflect.Type]string), - typeToFns: make(map[reflect.Type]Fns), +// Source is used for Type.StateLoad. +type Source struct { + internal objectDecoder } -// register registers a type under the given name. This will generally be -// called via init() methods, and therefore uses panic to propagate errors. -func (t *typeDatabase) register(name string, typ reflect.Type, fns Fns) { - // We can't allow name collisions. - if ot, ok := t.nameToType[name]; ok { - panic(fmt.Sprintf("type %q can't use name %q, already in use by type %q", typ.Name(), name, ot.Name())) - } - - // Or multiple registrations. - if on, ok := t.typeToName[typ]; ok { - panic(fmt.Sprintf("type %q can't be registered as %q, already registered as %q", typ.Name(), name, on)) - } - - t.nameToType[name] = typ - t.typeToName[typ] = name - t.typeToFns[typ] = fns +// Load loads the given object passed as a pointer.. +// +// See Sink.Save for an example. +func (s Source) Load(slot int, objPtr interface{}) { + s.internal.load(slot, reflect.ValueOf(objPtr), false, nil) } -// lookupType finds a type given a name. -func (t *typeDatabase) lookupType(name string) (reflect.Type, bool) { - typ, ok := t.nameToType[name] - return typ, ok +// LoadWait loads the given objects from the map, and marks it as requiring all +// AfterLoad executions to complete prior to running this object's AfterLoad. +// +// See Sink.Save for an example. +func (s Source) LoadWait(slot int, objPtr interface{}) { + s.internal.load(slot, reflect.ValueOf(objPtr), true, nil) } -// lookupName finds a name given a type. -func (t *typeDatabase) lookupName(typ reflect.Type) (string, bool) { - name, ok := t.typeToName[typ] - return name, ok +// LoadValue loads the given object value from the map. +// +// See Sink.SaveValue for an example. +func (s Source) LoadValue(slot int, objPtr interface{}, fn func(interface{})) { + o := reflect.ValueOf(objPtr) + s.internal.load(slot, o, true, func() { fn(o.Elem().Interface()) }) } -// lookupFns finds functions given a type. -func (t *typeDatabase) lookupFns(typ reflect.Type) (Fns, bool) { - fns, ok := t.typeToFns[typ] - return fns, ok +// AfterLoad schedules a function execution when all objects have been +// allocated and their automated loading and customized load logic have been +// executed. fn will not be executed until all of current object's +// dependencies' AfterLoad() logic, if exist, have been executed. +func (s Source) AfterLoad(fn func()) { + s.internal.afterLoad(fn) } -// Register must be called for any interface implementation types that -// implements Loader. -// -// Register should be called either immediately after startup or via init() -// methods. Double registration of either names or types will result in a panic. -// -// No synchronization is provided; this should only be called in init. -// -// Example usage: -// -// state.Register("Foo", (*Foo)(nil), state.Fns{ -// Save: (*Foo).Save, -// Load: (*Foo).Load, -// }) -// -func Register(name string, instance interface{}, fns Fns) { - registeredTypes.register(name, reflect.TypeOf(instance), fns) +// Context returns the context object provided at load time. +func (s Source) Context() context.Context { + return s.internal.ds.ctx } // IsZeroValue checks if the given value is the zero value. @@ -244,72 +260,14 @@ func IsZeroValue(val interface{}) bool { return val == nil || reflect.ValueOf(val).Elem().IsZero() } -// step captures one encoding / decoding step. On each step, there is up to one -// choice made, which is captured by non-nil param. We intentionally do not -// eagerly create the final path string, as that will only be needed upon panic. -type step struct { - // dereference indicate if the current object is obtained by - // dereferencing a pointer. - dereference bool - - // format is the formatting string that takes param below, if - // non-nil. For example, in array indexing case, we have "[%d]". - format string - - // param stores the choice made at the current encoding / decoding step. - // For eaxmple, in array indexing case, param stores the index. When no - // choice is made, e.g. dereference, param should be nil. - param interface{} -} - -// recoverable is the state encoding / decoding panic recovery facility. It is -// also used to store encoding / decoding steps as well as the reference to the -// original queued object from which the current object is dispatched. The -// complete encoding / decoding path is synthesised from the steps in all queued -// objects leading to the current object. -type recoverable struct { - from *recoverable - steps []step +// Failf is a wrapper around panic that should be used to generate errors that +// can be caught during saving and loading. +func Failf(fmtStr string, v ...interface{}) { + panic(fmt.Errorf(fmtStr, v...)) } -// push enters a new context level. -func (sr *recoverable) push(dereference bool, format string, param interface{}) { - sr.steps = append(sr.steps, step{dereference, format, param}) -} - -// pop exits the current context level. -func (sr *recoverable) pop() { - if len(sr.steps) <= 1 { - return - } - sr.steps = sr.steps[:len(sr.steps)-1] -} - -// path returns the complete encoding / decoding path from root. This is only -// called upon panic. -func (sr *recoverable) path() string { - if sr.from == nil { - return "root" - } - p := sr.from.path() - for _, s := range sr.steps { - if s.dereference { - p = fmt.Sprintf("*(%s)", p) - } - if s.param == nil { - p += s.format - } else { - p += fmt.Sprintf(s.format, s.param) - } - } - return p -} - -func (sr *recoverable) copy() recoverable { - return recoverable{from: sr.from, steps: append([]step(nil), sr.steps...)} -} - -// safely executes the given function, catching a panic and unpacking as an error. +// safely executes the given function, catching a panic and unpacking as an +// error. // // The error flow through the state package uses panic and recover. There are // two important reasons for this: @@ -323,9 +281,15 @@ func (sr *recoverable) copy() recoverable { // method doesn't add a lot of value. If there are specific error conditions // that you'd like to handle, you should add appropriate functionality to // objects themselves prior to calling Save() and Load(). -func (sr *recoverable) safely(fn func()) (err error) { +func safely(fn func()) (err error) { defer func() { if r := recover(); r != nil { + if es, ok := r.(*ErrState); ok { + err = es // Propagate. + return + } + + // Build a new state error. es := new(ErrState) if e, ok := r.(error); ok { es.err = e @@ -333,8 +297,6 @@ func (sr *recoverable) safely(fn func()) (err error) { es.err = fmt.Errorf("%v", r) } - es.path = sr.path() - // Make a stack. We don't know how big it will be ahead // of time, but want to make sure we get the whole // thing. So we just do a stupid brute force approach. diff --git a/pkg/state/state_norace.go b/pkg/state/state_norace.go new file mode 100644 index 000000000..4281aed6d --- /dev/null +++ b/pkg/state/state_norace.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race + +package state + +var raceEnabled = false diff --git a/pkg/state/state_race.go b/pkg/state/state_race.go new file mode 100644 index 000000000..8232981ce --- /dev/null +++ b/pkg/state/state_race.go @@ -0,0 +1,19 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build race + +package state + +var raceEnabled = true diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go deleted file mode 100644 index d7221e9e8..000000000 --- a/pkg/state/state_test.go +++ /dev/null @@ -1,721 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package state - -import ( - "bytes" - "context" - "io/ioutil" - "math" - "reflect" - "testing" -) - -// TestCase is used to define a single success/failure testcase of -// serialization of a set of objects. -type TestCase struct { - // Name is the name of the test case. - Name string - - // Objects is the list of values to serialize. - Objects []interface{} - - // Fail is whether the test case is supposed to fail or not. - Fail bool -} - -// runTest runs all testcases. -func runTest(t *testing.T, tests []TestCase) { - for _, test := range tests { - t.Logf("TEST %s:", test.Name) - for i, root := range test.Objects { - t.Logf(" case#%d: %#v", i, root) - - // Save the passed object. - saveBuffer := &bytes.Buffer{} - saveObjectPtr := reflect.New(reflect.TypeOf(root)) - saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Save failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Save failed as expected: %v", err) - continue - } - - // Load a new copy of the object. - loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { - t.Errorf(" FAIL: Load failed unexpectedly: %v", err) - continue - } else if err != nil { - t.Logf(" PASS: Load failed as expected: %v", err) - continue - } - - // Compare the values. - loadedValue := loadObjectPtr.Elem().Interface() - if eq := reflect.DeepEqual(root, loadedValue); !eq && !test.Fail { - t.Errorf(" FAIL: Objects differs; got %#v", loadedValue) - continue - } else if !eq { - t.Logf(" PASS: Object different as expected.") - continue - } - - // Everything went okay. Is that good? - if test.Fail { - t.Errorf(" FAIL: Unexpected success.") - } else { - t.Logf(" PASS: Success.") - } - } - } -} - -// dumbStruct is a struct which does not implement the loader/saver interface. -// We expect that serialization of this struct will fail. -type dumbStruct struct { - A int - B int -} - -// smartStruct is a struct which does implement the loader/saver interface. -// We expect that serialization of this struct will succeed. -type smartStruct struct { - A int - B int -} - -func (s *smartStruct) save(m Map) { - m.Save("A", &s.A) - m.Save("B", &s.B) -} - -func (s *smartStruct) load(m Map) { - m.Load("A", &s.A) - m.Load("B", &s.B) -} - -// valueLoadStruct uses a value load. -type valueLoadStruct struct { - v int -} - -func (v *valueLoadStruct) save(m Map) { - m.SaveValue("v", v.v) -} - -func (v *valueLoadStruct) load(m Map) { - m.LoadValue("v", new(int), func(value interface{}) { - v.v = value.(int) - }) -} - -// afterLoadStruct has an AfterLoad function. -type afterLoadStruct struct { - v int -} - -func (a *afterLoadStruct) save(m Map) { -} - -func (a *afterLoadStruct) load(m Map) { - m.AfterLoad(func() { - a.v++ - }) -} - -// genericContainer is a generic dispatcher. -type genericContainer struct { - v interface{} -} - -func (g *genericContainer) save(m Map) { - m.Save("v", &g.v) -} - -func (g *genericContainer) load(m Map) { - m.Load("v", &g.v) -} - -// sliceContainer is a generic slice. -type sliceContainer struct { - v []interface{} -} - -func (s *sliceContainer) save(m Map) { - m.Save("v", &s.v) -} - -func (s *sliceContainer) load(m Map) { - m.Load("v", &s.v) -} - -// mapContainer is a generic map. -type mapContainer struct { - v map[int]interface{} -} - -func (mc *mapContainer) save(m Map) { - m.Save("v", &mc.v) -} - -func (mc *mapContainer) load(m Map) { - // Some of the test cases below assume legacy behavior wherein maps - // will automatically inherit dependencies. - m.LoadWait("v", &mc.v) -} - -// dumbMap is a map which does not implement the loader/saver interface. -// Serialization of this map will default to the standard encode/decode logic. -type dumbMap map[string]int - -// pointerStruct contains various pointers, shared and non-shared, and pointers -// to pointers. We expect that serialization will respect the structure. -type pointerStruct struct { - A *int - B *int - C *int - D *int - - AA **int - BB **int -} - -func (p *pointerStruct) save(m Map) { - m.Save("A", &p.A) - m.Save("B", &p.B) - m.Save("C", &p.C) - m.Save("D", &p.D) - m.Save("AA", &p.AA) - m.Save("BB", &p.BB) -} - -func (p *pointerStruct) load(m Map) { - m.Load("A", &p.A) - m.Load("B", &p.B) - m.Load("C", &p.C) - m.Load("D", &p.D) - m.Load("AA", &p.AA) - m.Load("BB", &p.BB) -} - -// testInterface is a trivial interface example. -type testInterface interface { - Foo() -} - -// testImpl is a trivial implementation of testInterface. -type testImpl struct { -} - -// Foo satisfies testInterface. -func (t *testImpl) Foo() { -} - -// testImpl is trivially serializable. -func (t *testImpl) save(m Map) { -} - -// testImpl is trivially serializable. -func (t *testImpl) load(m Map) { -} - -// testI demonstrates interface dispatching. -type testI struct { - I testInterface -} - -func (t *testI) save(m Map) { - m.Save("I", &t.I) -} - -func (t *testI) load(m Map) { - m.Load("I", &t.I) -} - -// cycleStruct is used to implement basic cycles. -type cycleStruct struct { - c *cycleStruct -} - -func (c *cycleStruct) save(m Map) { - m.Save("c", &c.c) -} - -func (c *cycleStruct) load(m Map) { - m.Load("c", &c.c) -} - -// badCycleStruct actually has deadlocking dependencies. -// -// This should pass if b.b = {nil|b} and fail otherwise. -type badCycleStruct struct { - b *badCycleStruct -} - -func (b *badCycleStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *badCycleStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(func() { - // This is not executable, since AfterLoad requires that the - // object and all dependencies are complete. This should cause - // a deadlock error during load. - }) -} - -// emptyStructPointer points to an empty struct. -type emptyStructPointer struct { - nothing *struct{} -} - -func (e *emptyStructPointer) save(m Map) { - m.Save("nothing", &e.nothing) -} - -func (e *emptyStructPointer) load(m Map) { - m.Load("nothing", &e.nothing) -} - -// truncateInteger truncates an integer. -type truncateInteger struct { - v int64 - v2 int32 -} - -func (t *truncateInteger) save(m Map) { - t.v2 = int32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = int64(t.v2) -} - -// truncateUnsignedInteger truncates an unsigned integer. -type truncateUnsignedInteger struct { - v uint64 - v2 uint32 -} - -func (t *truncateUnsignedInteger) save(m Map) { - t.v2 = uint32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateUnsignedInteger) load(m Map) { - m.Load("v", &t.v2) - t.v = uint64(t.v2) -} - -// truncateFloat truncates a floating point number. -type truncateFloat struct { - v float64 - v2 float32 -} - -func (t *truncateFloat) save(m Map) { - t.v2 = float32(t.v) - m.Save("v", &t.v) -} - -func (t *truncateFloat) load(m Map) { - m.Load("v", &t.v2) - t.v = float64(t.v2) -} - -func TestTypes(t *testing.T) { - // x and y are basic integers, while xp points to x. - x := 1 - y := 2 - xp := &x - - // cs is a single object cycle. - cs := cycleStruct{nil} - cs.c = &cs - - // cs1 and cs2 are in a two object cycle. - cs1 := cycleStruct{nil} - cs2 := cycleStruct{nil} - cs1.c = &cs2 - cs2.c = &cs1 - - // bs is a single object cycle. - bs := badCycleStruct{nil} - bs.b = &bs - - // bs2 and bs2 are in a deadlocking cycle. - bs1 := badCycleStruct{nil} - bs2 := badCycleStruct{nil} - bs1.b = &bs2 - bs2.b = &bs1 - - // regular nils. - var ( - nilmap dumbMap - nilslice []byte - ) - - // embed points to embedded fields. - embed1 := pointerStruct{} - embed1.AA = &embed1.A - embed2 := pointerStruct{} - embed2.BB = &embed2.B - - // es1 contains two structs pointing to the same empty struct. - es := emptyStructPointer{new(struct{})} - es1 := []emptyStructPointer{es, es} - - tests := []TestCase{ - { - Name: "bool", - Objects: []interface{}{ - true, - false, - }, - }, - { - Name: "integers", - Objects: []interface{}{ - int(0), - int(1), - int(-1), - int8(0), - int8(1), - int8(-1), - int16(0), - int16(1), - int16(-1), - int32(0), - int32(1), - int32(-1), - int64(0), - int64(1), - int64(-1), - }, - }, - { - Name: "unsigned integers", - Objects: []interface{}{ - uint(0), - uint(1), - uint8(0), - uint8(1), - uint16(0), - uint16(1), - uint32(1), - uint64(0), - uint64(1), - }, - }, - { - Name: "strings", - Objects: []interface{}{ - "", - "foo", - "bar", - "\xa0", - }, - }, - { - Name: "slices", - Objects: []interface{}{ - []int{-1, 0, 1}, - []*int{&x, &x, &x}, - []int{1, 2, 3}[0:1], - []int{1, 2, 3}[1:2], - make([]byte, 32), - make([]byte, 32)[:16], - make([]byte, 32)[:16:20], - nilslice, - }, - }, - { - Name: "arrays", - Objects: []interface{}{ - &[1048576]bool{false, true, false, true}, - &[1048576]uint8{0, 1, 2, 3}, - &[1048576]byte{0, 1, 2, 3}, - &[1048576]uint16{0, 1, 2, 3}, - &[1048576]uint{0, 1, 2, 3}, - &[1048576]uint32{0, 1, 2, 3}, - &[1048576]uint64{0, 1, 2, 3}, - &[1048576]uintptr{0, 1, 2, 3}, - &[1048576]int8{0, -1, -2, -3}, - &[1048576]int16{0, -1, -2, -3}, - &[1048576]int32{0, -1, -2, -3}, - &[1048576]int64{0, -1, -2, -3}, - &[1048576]float32{0, 1.1, 2.2, 3.3}, - &[1048576]float64{0, 1.1, 2.2, 3.3}, - }, - }, - { - Name: "pointers", - Objects: []interface{}{ - &pointerStruct{A: &x, B: &x, C: &y, D: &y, AA: &xp, BB: &xp}, - &pointerStruct{}, - }, - }, - { - Name: "empty struct", - Objects: []interface{}{ - struct{}{}, - }, - }, - { - Name: "unenlightened structs", - Objects: []interface{}{ - &dumbStruct{A: 1, B: 2}, - }, - Fail: true, - }, - { - Name: "enlightened structs", - Objects: []interface{}{ - &smartStruct{A: 1, B: 2}, - }, - }, - { - Name: "load-hooks", - Objects: []interface{}{ - &afterLoadStruct{v: 1}, - &valueLoadStruct{v: 1}, - &genericContainer{v: &afterLoadStruct{v: 1}}, - &genericContainer{v: &valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, - }, - }, - { - Name: "maps", - Objects: []interface{}{ - dumbMap{"a": -1, "b": 0, "c": 1}, - map[smartStruct]int{{}: 0, {A: 1}: 1}, - nilmap, - &mapContainer{v: map[int]interface{}{0: &smartStruct{A: 1}}}, - }, - }, - { - Name: "interfaces", - Objects: []interface{}{ - &testI{&testImpl{}}, - &testI{nil}, - &testI{(*testImpl)(nil)}, - }, - }, - { - Name: "unregistered-interfaces", - Objects: []interface{}{ - &genericContainer{v: afterLoadStruct{v: 1}}, - &genericContainer{v: valueLoadStruct{v: 1}}, - &sliceContainer{v: []interface{}{afterLoadStruct{v: 1}}}, - &sliceContainer{v: []interface{}{valueLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: afterLoadStruct{v: 1}}}, - &mapContainer{v: map[int]interface{}{0: valueLoadStruct{v: 1}}}, - }, - Fail: true, - }, - { - Name: "cycles", - Objects: []interface{}{ - &cs, - &cs1, - &cycleStruct{&cs1}, - &cycleStruct{&cs}, - &badCycleStruct{nil}, - &bs, - }, - }, - { - Name: "deadlock", - Objects: []interface{}{ - &bs1, - }, - Fail: true, - }, - { - Name: "embed", - Objects: []interface{}{ - &embed1, - &embed2, - }, - Fail: true, - }, - { - Name: "empty structs", - Objects: []interface{}{ - new(struct{}), - es, - es1, - }, - }, - { - Name: "truncated okay", - Objects: []interface{}{ - &truncateInteger{v: 1}, - &truncateUnsignedInteger{v: 1}, - &truncateFloat{v: 1.0}, - }, - }, - { - Name: "truncated bad", - Objects: []interface{}{ - &truncateInteger{v: math.MaxInt32 + 1}, - &truncateUnsignedInteger{v: math.MaxUint32 + 1}, - &truncateFloat{v: math.MaxFloat32 * 2}, - }, - Fail: true, - }, - } - - runTest(t, tests) -} - -// benchStruct is used for benchmarking. -type benchStruct struct { - b *benchStruct - - // Dummy data is included to ensure that these objects are large. - // This is to detect possible regression when registering objects. - _ [4096]byte -} - -func (b *benchStruct) save(m Map) { - m.Save("b", &b.b) -} - -func (b *benchStruct) load(m Map) { - m.LoadWait("b", &b.b) - m.AfterLoad(b.afterLoad) -} - -func (b *benchStruct) afterLoad() { - // Do nothing, just force scheduling. -} - -// buildObject builds a benchmark object. -func buildObject(n int) (b *benchStruct) { - for i := 0; i < n; i++ { - b = &benchStruct{b: b} - } - return -} - -func BenchmarkEncoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var stats Stats - b.StartTimer() - if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { - b.Errorf("save failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func BenchmarkDecoding(b *testing.B) { - b.StopTimer() - bs := buildObject(b.N) - var newBS benchStruct - buf := &bytes.Buffer{} - if err := Save(context.Background(), buf, bs, nil); err != nil { - b.Errorf("save failed: %v", err) - } - var stats Stats - b.StartTimer() - if err := Load(context.Background(), buf, &newBS, &stats); err != nil { - b.Errorf("load failed: %v", err) - } - b.StopTimer() - if b.N > 1000 { - b.Logf("breakdown (n=%d): %s", b.N, &stats) - } -} - -func init() { - Register("stateTest.smartStruct", (*smartStruct)(nil), Fns{ - Save: (*smartStruct).save, - Load: (*smartStruct).load, - }) - Register("stateTest.afterLoadStruct", (*afterLoadStruct)(nil), Fns{ - Save: (*afterLoadStruct).save, - Load: (*afterLoadStruct).load, - }) - Register("stateTest.valueLoadStruct", (*valueLoadStruct)(nil), Fns{ - Save: (*valueLoadStruct).save, - Load: (*valueLoadStruct).load, - }) - Register("stateTest.genericContainer", (*genericContainer)(nil), Fns{ - Save: (*genericContainer).save, - Load: (*genericContainer).load, - }) - Register("stateTest.sliceContainer", (*sliceContainer)(nil), Fns{ - Save: (*sliceContainer).save, - Load: (*sliceContainer).load, - }) - Register("stateTest.mapContainer", (*mapContainer)(nil), Fns{ - Save: (*mapContainer).save, - Load: (*mapContainer).load, - }) - Register("stateTest.pointerStruct", (*pointerStruct)(nil), Fns{ - Save: (*pointerStruct).save, - Load: (*pointerStruct).load, - }) - Register("stateTest.testImpl", (*testImpl)(nil), Fns{ - Save: (*testImpl).save, - Load: (*testImpl).load, - }) - Register("stateTest.testI", (*testI)(nil), Fns{ - Save: (*testI).save, - Load: (*testI).load, - }) - Register("stateTest.cycleStruct", (*cycleStruct)(nil), Fns{ - Save: (*cycleStruct).save, - Load: (*cycleStruct).load, - }) - Register("stateTest.badCycleStruct", (*badCycleStruct)(nil), Fns{ - Save: (*badCycleStruct).save, - Load: (*badCycleStruct).load, - }) - Register("stateTest.emptyStructPointer", (*emptyStructPointer)(nil), Fns{ - Save: (*emptyStructPointer).save, - Load: (*emptyStructPointer).load, - }) - Register("stateTest.truncateInteger", (*truncateInteger)(nil), Fns{ - Save: (*truncateInteger).save, - Load: (*truncateInteger).load, - }) - Register("stateTest.truncateUnsignedInteger", (*truncateUnsignedInteger)(nil), Fns{ - Save: (*truncateUnsignedInteger).save, - Load: (*truncateUnsignedInteger).load, - }) - Register("stateTest.truncateFloat", (*truncateFloat)(nil), Fns{ - Save: (*truncateFloat).save, - Load: (*truncateFloat).load, - }) - Register("stateTest.benchStruct", (*benchStruct)(nil), Fns{ - Save: (*benchStruct).save, - Load: (*benchStruct).load, - }) -} diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD index e7581c09b..d6c89c7e9 100644 --- a/pkg/state/statefile/BUILD +++ b/pkg/state/statefile/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/binary", "//pkg/compressio", + "//pkg/state/wire", ], ) diff --git a/pkg/state/statefile/statefile.go b/pkg/state/statefile/statefile.go index c0f4c4954..bdfb800fb 100644 --- a/pkg/state/statefile/statefile.go +++ b/pkg/state/statefile/statefile.go @@ -57,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/compressio" + "gvisor.dev/gvisor/pkg/state/wire" ) // keySize is the AES-256 key length. @@ -83,10 +84,16 @@ var ErrInvalidMetadataLength = fmt.Errorf("metadata length invalid, maximum size // ErrMetadataInvalid is returned if passed metadata is invalid. var ErrMetadataInvalid = fmt.Errorf("metadata invalid, can't start with _") +// WriteCloser is an io.Closer and wire.Writer. +type WriteCloser interface { + wire.Writer + io.Closer +} + // NewWriter returns a state data writer for a statefile. // // Note that the returned WriteCloser must be closed. -func NewWriter(w io.Writer, key []byte, metadata map[string]string) (io.WriteCloser, error) { +func NewWriter(w io.Writer, key []byte, metadata map[string]string) (WriteCloser, error) { if metadata == nil { metadata = make(map[string]string) } @@ -215,7 +222,7 @@ func metadata(r io.Reader, h hash.Hash) (map[string]string, error) { } // NewReader returns a reader for a statefile. -func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { +func NewReader(r io.Reader, key []byte) (wire.Reader, map[string]string, error) { // Read the metadata with the hash. h := hmac.New(sha256.New, key) metadata, err := metadata(r, h) @@ -224,9 +231,9 @@ func NewReader(r io.Reader, key []byte) (io.Reader, map[string]string, error) { } // Wrap in compression. - rc, err := compressio.NewReader(r, key) + cr, err := compressio.NewReader(r, key) if err != nil { return nil, nil, err } - return rc, metadata, nil + return cr, metadata, nil } diff --git a/pkg/state/stats.go b/pkg/state/stats.go index eb51cda47..eaec664a1 100644 --- a/pkg/state/stats.go +++ b/pkg/state/stats.go @@ -17,7 +17,6 @@ package state import ( "bytes" "fmt" - "reflect" "sort" "time" ) @@ -35,92 +34,81 @@ type statEntry struct { // All exported receivers accept nil. type Stats struct { // byType contains a breakdown of time spent by type. - byType map[reflect.Type]*statEntry + // + // This is indexed *directly* by typeID, including zero. + byType []statEntry // stack contains objects in progress. - stack []reflect.Type + stack []typeID + + // names contains type names. + // + // This is also indexed *directly* by typeID, including zero, which we + // hard-code as "state.default". This is only resolved by calling fini + // on the stats object. + names []string // last is the last start time. last time.Time } -// sample adds the samples to the given object. -func (s *Stats) sample(typ reflect.Type) { - now := time.Now() - s.byType[typ].total += now.Sub(s.last) - s.last = now +// init initializes statistics. +func (s *Stats) init() { + s.last = time.Now() + s.stack = append(s.stack, 0) } -// Add adds a sample count. -func (s *Stats) Add(obj reflect.Value) { - if s == nil { - return - } - if s.byType == nil { - s.byType = make(map[reflect.Type]*statEntry) - } - typ := obj.Type() - entry, ok := s.byType[typ] - if !ok { - entry = new(statEntry) - s.byType[typ] = entry +// fini finalizes statistics. +func (s *Stats) fini(resolve func(id typeID) string) { + s.done() + + // Resolve all type names. + s.names = make([]string, len(s.byType)) + s.names[0] = "state.default" // See above. + for id := typeID(1); int(id) < len(s.names); id++ { + s.names[id] = resolve(id) } - entry.count++ } -// Remove removes a sample count. It should only be called after a previous -// Add(). -func (s *Stats) Remove(obj reflect.Value) { - if s == nil { - return +// sample adds the samples to the given object. +func (s *Stats) sample(id typeID) { + now := time.Now() + if len(s.byType) <= int(id) { + // Allocate all the missing entries in one fell swoop. + s.byType = append(s.byType, make([]statEntry, 1+int(id)-len(s.byType))...) } - typ := obj.Type() - entry := s.byType[typ] - entry.count-- + s.byType[id].total += now.Sub(s.last) + s.last = now } -// Start starts a sample. -func (s *Stats) Start(obj reflect.Value) { - if s == nil { - return - } - if len(s.stack) > 0 { - last := s.stack[len(s.stack)-1] - s.sample(last) - } else { - // First time sample. - s.last = time.Now() - } - s.stack = append(s.stack, obj.Type()) +// start starts a sample. +func (s *Stats) start(id typeID) { + last := s.stack[len(s.stack)-1] + s.sample(last) + s.stack = append(s.stack, id) } -// Done finishes the current sample. -func (s *Stats) Done() { - if s == nil { - return - } +// done finishes the current sample. +func (s *Stats) done() { last := s.stack[len(s.stack)-1] s.sample(last) + s.byType[last].count++ s.stack = s.stack[:len(s.stack)-1] } type sliceEntry struct { - typ reflect.Type + name string entry *statEntry } // String returns a table representation of the stats. func (s *Stats) String() string { - if s == nil || len(s.byType) == 0 { - return "(no data)" - } - // Build a list of stat entries. ss := make([]sliceEntry, 0, len(s.byType)) - for typ, entry := range s.byType { + for id := 0; id < len(s.names); id++ { ss = append(ss, sliceEntry{ - typ: typ, - entry: entry, + name: s.names[id], + entry: &s.byType[id], }) } @@ -136,17 +124,22 @@ func (s *Stats) String() string { total time.Duration ) buf.WriteString("\n") - buf.WriteString(fmt.Sprintf("%12s | %8s | %8s | %s\n", "total", "count", "per", "type")) - buf.WriteString("-------------+----------+----------+-------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8s | % 16s | %s\n", "total", "count", "per", "type")) + buf.WriteString("-----------------+----------+------------------+----------------\n") for _, se := range ss { + if se.entry.count == 0 { + // Since we store all types linearly, we are not + // guaranteed that any entry actually has time. + continue + } count += se.entry.count total += se.entry.total per := se.entry.total / time.Duration(se.entry.count) - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | %s\n", - se.entry.total, se.entry.count, per, se.typ.String())) + buf.WriteString(fmt.Sprintf("% 16s | %8d | % 16s | %s\n", + se.entry.total, se.entry.count, per, se.name)) } - buf.WriteString("-------------+----------+----------+-------------\n") - buf.WriteString(fmt.Sprintf("%12s | %8d | %8s | [all]", + buf.WriteString("-----------------+----------+------------------+----------------\n") + buf.WriteString(fmt.Sprintf("% 16s | % 8d | % 16s | [all]", total, count, total/time.Duration(count))) return string(buf.Bytes()) } diff --git a/pkg/state/tests/BUILD b/pkg/state/tests/BUILD new file mode 100644 index 000000000..9297cafbe --- /dev/null +++ b/pkg/state/tests/BUILD @@ -0,0 +1,43 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tests", + srcs = [ + "array.go", + "bench.go", + "integer.go", + "load.go", + "map.go", + "register.go", + "struct.go", + "tests.go", + ], + deps = [ + "//pkg/state", + "//pkg/state/pretty", + ], +) + +go_test( + name = "tests_test", + size = "small", + srcs = [ + "array_test.go", + "bench_test.go", + "bool_test.go", + "float_test.go", + "integer_test.go", + "load_test.go", + "map_test.go", + "register_test.go", + "string_test.go", + "struct_test.go", + ], + library = ":tests", + deps = [ + "//pkg/state", + "//pkg/state/wire", + ], +) diff --git a/pkg/state/tests/array.go b/pkg/state/tests/array.go new file mode 100644 index 000000000..0972a80e7 --- /dev/null +++ b/pkg/state/tests/array.go @@ -0,0 +1,35 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type arrayContainer struct { + v [1]interface{} +} + +// +stateify savable +type arrayPtrContainer struct { + v *[1]interface{} +} + +// +stateify savable +type sliceContainer struct { + v []interface{} +} + +// +stateify savable +type slicePtrContainer struct { + v *[]interface{} +} diff --git a/pkg/state/tests/array_test.go b/pkg/state/tests/array_test.go new file mode 100644 index 000000000..a347b2947 --- /dev/null +++ b/pkg/state/tests/array_test.go @@ -0,0 +1,134 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "reflect" + "testing" +) + +var allArrayPrimitives = []interface{}{ + [1]bool{}, + [1]bool{true}, + [2]bool{false, true}, + [1]int{}, + [1]int{1}, + [2]int{0, 1}, + [1]int8{}, + [1]int8{1}, + [2]int8{0, 1}, + [1]int16{}, + [1]int16{1}, + [2]int16{0, 1}, + [1]int32{}, + [1]int32{1}, + [2]int32{0, 1}, + [1]int64{}, + [1]int64{1}, + [2]int64{0, 1}, + [1]uint{}, + [1]uint{1}, + [2]uint{0, 1}, + [1]uintptr{}, + [1]uintptr{1}, + [2]uintptr{0, 1}, + [1]uint8{}, + [1]uint8{1}, + [2]uint8{0, 1}, + [1]uint16{}, + [1]uint16{1}, + [2]uint16{0, 1}, + [1]uint32{}, + [1]uint32{1}, + [2]uint32{0, 1}, + [1]uint64{}, + [1]uint64{1}, + [2]uint64{0, 1}, + [1]string{}, + [1]string{""}, + [1]string{nonEmptyString}, + [2]string{"", nonEmptyString}, +} + +func TestArrayPrimitives(t *testing.T) { + runTestCases(t, false, "plain", flatten(allArrayPrimitives)) + runTestCases(t, false, "pointers", pointersTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allArrayPrimitives))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allArrayPrimitives)))) +} + +func TestSlices(t *testing.T) { + var allSlices = flatten( + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + return v.Slice(0, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the pure "nil" value for the slice. + return reflect.New(v.Slice(0, 0).Type()).Elem().Interface(), true + } + return v.Slice(1, v.Len()).Interface(), true + }), + filter(allArrayPrimitives, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)).Elem() + v.Set(reflect.ValueOf(o)) + if v.Len() == 0 { + // Return the zero-valued slice. + return reflect.MakeSlice(v.Slice(0, 0).Type(), 0, 0).Interface(), true + } + return v.Slice(0, v.Len()-1).Interface(), true + }), + ) + runTestCases(t, false, "plain", allSlices) + runTestCases(t, false, "pointers", pointersTo(allSlices)) + runTestCases(t, false, "interfaces", interfacesTo(allSlices)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(allSlices))) +} + +func TestArrayContainers(t *testing.T) { + var ( + emptyArray [1]interface{} + fullArray [1]interface{} + ) + fullArray[0] = &emptyArray + runTestCases(t, false, "", []interface{}{ + arrayContainer{v: emptyArray}, + arrayContainer{v: fullArray}, + arrayPtrContainer{v: nil}, + arrayPtrContainer{v: &emptyArray}, + arrayPtrContainer{v: &fullArray}, + }) +} + +func TestSliceContainers(t *testing.T) { + var ( + nilSlice []interface{} + emptySlice = make([]interface{}, 0) + fullSlice = []interface{}{nil} + ) + runTestCases(t, false, "", []interface{}{ + sliceContainer{v: nilSlice}, + sliceContainer{v: emptySlice}, + sliceContainer{v: fullSlice}, + slicePtrContainer{v: nil}, + slicePtrContainer{v: &nilSlice}, + slicePtrContainer{v: &emptySlice}, + slicePtrContainer{v: &fullSlice}, + }) +} diff --git a/pkg/state/tests/bench.go b/pkg/state/tests/bench.go new file mode 100644 index 000000000..40869cdfb --- /dev/null +++ b/pkg/state/tests/bench.go @@ -0,0 +1,24 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type benchStruct struct { + B *benchStruct // Must be exported for gob. +} + +func (b *benchStruct) afterLoad() { + // Do nothing, just force scheduling. +} diff --git a/pkg/state/tests/bench_test.go b/pkg/state/tests/bench_test.go new file mode 100644 index 000000000..7e102c907 --- /dev/null +++ b/pkg/state/tests/bench_test.go @@ -0,0 +1,153 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "testing" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/wire" +) + +// buildPtrObject builds a benchmark object. +func buildPtrObject(n int) interface{} { + b := new(benchStruct) + for i := 0; i < n; i++ { + b = &benchStruct{B: b} + } + return b +} + +// buildMapObject builds a benchmark object. +func buildMapObject(n int) interface{} { + b := new(benchStruct) + m := make(map[int]*benchStruct) + for i := 0; i < n; i++ { + m[i] = b + } + return &m +} + +// buildSliceObject builds a benchmark object. +func buildSliceObject(n int) interface{} { + b := new(benchStruct) + s := make([]*benchStruct, 0, n) + for i := 0; i < n; i++ { + s = append(s, b) + } + return &s +} + +var allObjects = map[string]struct { + New func(int) interface{} +}{ + "ptr": { + New: buildPtrObject, + }, + "map": { + New: buildMapObject, + }, + "slice": { + New: buildSliceObject, + }, +} + +func buildObjects(n int, fn func(int) interface{}) (iters int, v interface{}) { + // maxSize is the maximum size of an individual object below. For an N + // larger than this, we start to return multiple objects. + const maxSize = 1024 + if n <= maxSize { + return 1, fn(n) + } + iters = (n + maxSize - 1) / maxSize + return iters, fn(maxSize) +} + +// gobSave is a version of save using gob (no stats available). +func gobSave(_ context.Context, w wire.Writer, v interface{}) (_ state.Stats, err error) { + enc := gob.NewEncoder(w) + err = enc.Encode(v) + return +} + +// gobLoad is a version of load using gob (no stats available). +func gobLoad(_ context.Context, r wire.Reader, v interface{}) (_ state.Stats, err error) { + dec := gob.NewDecoder(r) + err = dec.Decode(v) + return +} + +var allAlgos = map[string]struct { + Save func(context.Context, wire.Writer, interface{}) (state.Stats, error) + Load func(context.Context, wire.Reader, interface{}) (state.Stats, error) + MaxPtr int +}{ + "state": { + Save: state.Save, + Load: state.Load, + }, + "gob": { + Save: gobSave, + Load: gobLoad, + }, +} + +func BenchmarkEncoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + b.ReportAllocs() + b.StartTimer() + for i := 0; i < n; i++ { + if _, err := algoInfo.Save(context.Background(), discard{}, v); err != nil { + b.Errorf("save failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} + +func BenchmarkDecoding(b *testing.B) { + for objName, objInfo := range allObjects { + for algoName, algoInfo := range allAlgos { + b.Run(fmt.Sprintf("%s/%s", objName, algoName), func(b *testing.B) { + b.StopTimer() + n, v := buildObjects(b.N, objInfo.New) + buf := new(bytes.Buffer) + if _, err := algoInfo.Save(context.Background(), buf, v); err != nil { + b.Errorf("save failed: %v", err) + } + b.ReportAllocs() + b.StartTimer() + var r bytes.Reader + for i := 0; i < n; i++ { + r.Reset(buf.Bytes()) + if _, err := algoInfo.Load(context.Background(), &r, v); err != nil { + b.Errorf("load failed: %v", err) + } + } + b.StopTimer() + }) + } + } +} diff --git a/pkg/state/tests/bool_test.go b/pkg/state/tests/bool_test.go new file mode 100644 index 000000000..e17cfacf9 --- /dev/null +++ b/pkg/state/tests/bool_test.go @@ -0,0 +1,31 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +var allBools = []bool{ + true, + false, +} + +func TestBool(t *testing.T) { + runTestCases(t, false, "plain", flatten(allBools)) + runTestCases(t, false, "pointers", pointersTo(flatten(allBools))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allBools))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allBools)))) +} diff --git a/pkg/state/tests/float_test.go b/pkg/state/tests/float_test.go new file mode 100644 index 000000000..3e89edd9c --- /dev/null +++ b/pkg/state/tests/float_test.go @@ -0,0 +1,118 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "math" + "testing" +) + +var safeFloat32s = []float32{ + float32(0.0), + float32(1.0), + float32(-1.0), + float32(math.Inf(1)), + float32(math.Inf(-1)), +} + +var allFloat32s = append(safeFloat32s, float32(math.NaN())) + +var safeFloat64s = []float64{ + float64(0.0), + float64(1.0), + float64(-1.0), + math.Inf(1), + math.Inf(-1), +} + +var allFloat64s = append(safeFloat64s, math.NaN()) + +func TestFloat(t *testing.T) { + runTestCases(t, false, "plain", flatten( + allFloat32s, + allFloat64s, + )) + // See checkEqual for why NaNs are missing. + runTestCases(t, false, "pointers", pointersTo(flatten( + safeFloat32s, + safeFloat64s, + ))) + runTestCases(t, false, "interfaces", interfacesTo(flatten( + safeFloat32s, + safeFloat64s, + ))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten( + safeFloat32s, + safeFloat64s, + )))) +} + +const onlyDouble float64 = 1.0000000000000002 + +func TestFloatTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingFloat32{save: onlyDouble}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingFloat32{save: 1.0}, + }) +} + +var safeComplex64s = combine(safeFloat32s, safeFloat32s, func(i, j interface{}) interface{} { + return complex(i.(float32), j.(float32)) +}) + +var allComplex64s = combine(allFloat32s, allFloat32s, func(i, j interface{}) interface{} { + return complex(i.(float32), j.(float32)) +}) + +var safeComplex128s = combine(safeFloat64s, safeFloat64s, func(i, j interface{}) interface{} { + return complex(i.(float64), j.(float64)) +}) + +var allComplex128s = combine(allFloat64s, allFloat64s, func(i, j interface{}) interface{} { + return complex(i.(float64), j.(float64)) +}) + +func TestComplex(t *testing.T) { + runTestCases(t, false, "plain", flatten( + allComplex64s, + allComplex128s, + )) + // See TestFloat; same issue. + runTestCases(t, false, "pointers", pointersTo(flatten( + safeComplex64s, + safeComplex128s, + ))) + runTestCases(t, false, "interfacse", interfacesTo(flatten( + safeComplex64s, + safeComplex128s, + ))) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(flatten( + safeComplex64s, + safeComplex128s, + )))) +} + +func TestComplexTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingComplex64{save: complex(onlyDouble, onlyDouble)}, + truncatingComplex64{save: complex(1.0, onlyDouble)}, + truncatingComplex64{save: complex(onlyDouble, 1.0)}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingComplex64{save: complex(1.0, 1.0)}, + }) +} diff --git a/pkg/state/tests/integer.go b/pkg/state/tests/integer.go new file mode 100644 index 000000000..ca403eed1 --- /dev/null +++ b/pkg/state/tests/integer.go @@ -0,0 +1,163 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "gvisor.dev/gvisor/pkg/state" +) + +// +stateify type +type truncatingUint8 struct { + save uint64 + load uint8 `state:"nosave"` +} + +func (t *truncatingUint8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint8)(nil) + +// +stateify type +type truncatingUint16 struct { + save uint64 + load uint16 `state:"nosave"` +} + +func (t *truncatingUint16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint16)(nil) + +// +stateify type +type truncatingUint32 struct { + save uint64 + load uint32 `state:"nosave"` +} + +func (t *truncatingUint32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingUint32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = uint64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingUint32)(nil) + +// +stateify type +type truncatingInt8 struct { + save int64 + load int8 `state:"nosave"` +} + +func (t *truncatingInt8) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt8) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt8)(nil) + +// +stateify type +type truncatingInt16 struct { + save int64 + load int16 `state:"nosave"` +} + +func (t *truncatingInt16) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt16) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt16)(nil) + +// +stateify type +type truncatingInt32 struct { + save int64 + load int32 `state:"nosave"` +} + +func (t *truncatingInt32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingInt32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = int64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingInt32)(nil) + +// +stateify type +type truncatingFloat32 struct { + save float64 + load float32 `state:"nosave"` +} + +func (t *truncatingFloat32) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingFloat32) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = float64(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingFloat32)(nil) + +// +stateify type +type truncatingComplex64 struct { + save complex128 + load complex64 `state:"nosave"` +} + +func (t *truncatingComplex64) StateSave(m state.Sink) { + m.Save(0, &t.save) +} + +func (t *truncatingComplex64) StateLoad(m state.Source) { + m.Load(0, &t.load) + t.save = complex128(t.load) + t.load = 0 +} + +var _ state.SaverLoader = (*truncatingComplex64)(nil) diff --git a/pkg/state/tests/integer_test.go b/pkg/state/tests/integer_test.go new file mode 100644 index 000000000..d3931c952 --- /dev/null +++ b/pkg/state/tests/integer_test.go @@ -0,0 +1,94 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "math" + "testing" +) + +var ( + allIntTs = []int{-1, 0, 1} + allInt8s = []int8{math.MinInt8, -1, 0, 1, math.MaxInt8} + allInt16s = []int16{math.MinInt16, -1, 0, 1, math.MaxInt16} + allInt32s = []int32{math.MinInt32, -1, 0, 1, math.MaxInt32} + allInt64s = []int64{math.MinInt64, -1, 0, 1, math.MaxInt64} + allUintTs = []uint{0, 1} + allUintptrs = []uintptr{0, 1, ^uintptr(0)} + allUint8s = []uint8{0, 1, math.MaxUint8} + allUint16s = []uint16{0, 1, math.MaxUint16} + allUint32s = []uint32{0, 1, math.MaxUint32} + allUint64s = []uint64{0, 1, math.MaxUint64} +) + +var allInts = flatten( + allIntTs, + allInt8s, + allInt16s, + allInt32s, + allInt64s, +) + +var allUints = flatten( + allUintTs, + allUintptrs, + allUint8s, + allUint16s, + allUint32s, + allUint64s, +) + +func TestInt(t *testing.T) { + runTestCases(t, false, "plain", allInts) + runTestCases(t, false, "pointers", pointersTo(allInts)) + runTestCases(t, false, "interfaces", interfacesTo(allInts)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allInts))) +} + +func TestIntTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingInt8{save: math.MinInt8 - 1}, + truncatingInt16{save: math.MinInt16 - 1}, + truncatingInt32{save: math.MinInt32 - 1}, + truncatingInt8{save: math.MaxInt8 + 1}, + truncatingInt16{save: math.MaxInt16 + 1}, + truncatingInt32{save: math.MaxInt32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingInt8{save: 1}, + truncatingInt16{save: 1}, + truncatingInt32{save: 1}, + }) +} + +func TestUint(t *testing.T) { + runTestCases(t, false, "plain", allUints) + runTestCases(t, false, "pointers", pointersTo(allUints)) + runTestCases(t, false, "interfaces", interfacesTo(allUints)) + runTestCases(t, false, "interfacesTo", interfacesTo(pointersTo(allUints))) +} + +func TestUintTruncation(t *testing.T) { + runTestCases(t, true, "pass", []interface{}{ + truncatingUint8{save: math.MaxUint8 + 1}, + truncatingUint16{save: math.MaxUint16 + 1}, + truncatingUint32{save: math.MaxUint32 + 1}, + }) + runTestCases(t, false, "fail", []interface{}{ + truncatingUint8{save: 1}, + truncatingUint16{save: 1}, + truncatingUint32{save: 1}, + }) +} diff --git a/pkg/state/tests/load.go b/pkg/state/tests/load.go new file mode 100644 index 000000000..a8350c0f3 --- /dev/null +++ b/pkg/state/tests/load.go @@ -0,0 +1,61 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type genericContainer struct { + v interface{} +} + +// +stateify savable +type afterLoadStruct struct { + v int `state:"nosave"` +} + +func (a *afterLoadStruct) afterLoad() { + a.v++ +} + +// +stateify savable +type valueLoadStruct struct { + v int `state:".(int64)"` +} + +func (v *valueLoadStruct) saveV() int64 { + return int64(v.v) // Save as int64. +} + +func (v *valueLoadStruct) loadV(value int64) { + v.v = int(value) // Load as int. +} + +// +stateify savable +type cycleStruct struct { + c *cycleStruct +} + +// +stateify savable +type badCycleStruct struct { + b *badCycleStruct `state:"wait"` +} + +func (b *badCycleStruct) afterLoad() { + if b.b != b { + // This is not executable, since AfterLoad requires that the + // object and all dependencies are complete. This should cause + // a deadlock error during load. + panic("badCycleStruct.afterLoad called") + } +} diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go new file mode 100644 index 000000000..1e9794296 --- /dev/null +++ b/pkg/state/tests/load_test.go @@ -0,0 +1,70 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +func TestLoadHooks(t *testing.T) { + runTestCases(t, false, "load-hooks", []interface{}{ + &afterLoadStruct{v: 1}, + &valueLoadStruct{v: 1}, + &genericContainer{v: &afterLoadStruct{v: 1}}, + &genericContainer{v: &valueLoadStruct{v: 1}}, + &sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, + &sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &afterLoadStruct{v: 1}}}, + &mapContainer{v: map[int]interface{}{0: &valueLoadStruct{v: 1}}}, + }) +} + +func TestCycles(t *testing.T) { + // cs is a single object cycle. + cs := cycleStruct{nil} + cs.c = &cs + + // cs1 and cs2 are in a two object cycle. + cs1 := cycleStruct{nil} + cs2 := cycleStruct{nil} + cs1.c = &cs2 + cs2.c = &cs1 + + runTestCases(t, false, "cycles", []interface{}{ + cs, + cs1, + }) +} + +func TestDeadlock(t *testing.T) { + // bs is a single object cycle. This does not cause deadlock because an + // object cannot wait for itself. + bs := badCycleStruct{nil} + bs.b = &bs + + runTestCases(t, false, "self", []interface{}{ + &bs, + }) + + // bs2 and bs2 are in a deadlocking cycle. + bs1 := badCycleStruct{nil} + bs2 := badCycleStruct{nil} + bs1.b = &bs2 + bs2.b = &bs1 + + runTestCases(t, true, "deadlock", []interface{}{ + &bs1, + }) +} diff --git a/pkg/state/tests/map.go b/pkg/state/tests/map.go new file mode 100644 index 000000000..db4e548f1 --- /dev/null +++ b/pkg/state/tests/map.go @@ -0,0 +1,28 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type mapContainer struct { + v map[int]interface{} +} + +// +stateify savable +type mapPtrContainer struct { + v *map[int]interface{} +} + +// +stateify savable +type registeredMapStruct struct{} diff --git a/pkg/state/tests/map_test.go b/pkg/state/tests/map_test.go new file mode 100644 index 000000000..92bf0fc01 --- /dev/null +++ b/pkg/state/tests/map_test.go @@ -0,0 +1,90 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "reflect" + "testing" +) + +var allMapPrimitives = []interface{}{ + bool(true), + int(1), + int8(1), + int16(1), + int32(1), + int64(1), + uint(1), + uintptr(1), + uint8(1), + uint16(1), + uint32(1), + uint64(1), + string(""), + registeredMapStruct{}, +} + +var allMapKeys = flatten(allMapPrimitives, pointersTo(allMapPrimitives)) + +var allMapValues = flatten(allMapPrimitives, pointersTo(allMapPrimitives), interfacesTo(allMapPrimitives)) + +var emptyMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + return m.Interface() +}) + +var fullMaps = combine(allMapKeys, allMapValues, func(v1, v2 interface{}) interface{} { + m := reflect.MakeMap(reflect.MapOf(reflect.TypeOf(v1), reflect.TypeOf(v2))) + m.SetMapIndex(reflect.Zero(reflect.TypeOf(v1)), reflect.Zero(reflect.TypeOf(v2))) + return m.Interface() +}) + +func TestMapAliasing(t *testing.T) { + v := make(map[int]int) + ptrToV := &v + aliases := []map[int]int{v, v} + runTestCases(t, false, "", []interface{}{ptrToV, aliases}) +} + +func TestMapsEmpty(t *testing.T) { + runTestCases(t, false, "plain", emptyMaps) + runTestCases(t, false, "pointers", pointersTo(emptyMaps)) + runTestCases(t, false, "interfaces", interfacesTo(emptyMaps)) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(emptyMaps))) +} + +func TestMapsFull(t *testing.T) { + runTestCases(t, false, "plain", fullMaps) + runTestCases(t, false, "pointers", pointersTo(fullMaps)) + runTestCases(t, false, "interfaces", interfacesTo(fullMaps)) + runTestCases(t, false, "interfacesToPointer", interfacesTo(pointersTo(fullMaps))) +} + +func TestMapContainers(t *testing.T) { + var ( + nilMap map[int]interface{} + emptyMap = make(map[int]interface{}) + fullMap = map[int]interface{}{0: nil} + ) + runTestCases(t, false, "", []interface{}{ + mapContainer{v: nilMap}, + mapContainer{v: emptyMap}, + mapContainer{v: fullMap}, + mapPtrContainer{v: nil}, + mapPtrContainer{v: &nilMap}, + mapPtrContainer{v: &emptyMap}, + mapPtrContainer{v: &fullMap}, + }) +} diff --git a/pkg/state/tests/register.go b/pkg/state/tests/register.go new file mode 100644 index 000000000..074d86315 --- /dev/null +++ b/pkg/state/tests/register.go @@ -0,0 +1,21 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +// +stateify savable +type alreadyRegisteredStruct struct{} + +// +stateify savable +type alreadyRegisteredOther int diff --git a/pkg/state/tests/register_test.go b/pkg/state/tests/register_test.go new file mode 100644 index 000000000..c829753cc --- /dev/null +++ b/pkg/state/tests/register_test.go @@ -0,0 +1,167 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +// faker calls itself whatever is in the name field. +type faker struct { + Name string + Fields []string +} + +func (f *faker) StateTypeName() string { + return f.Name +} + +func (f *faker) StateFields() []string { + return f.Fields +} + +// fakerWithSaverLoader has all it needs. +type fakerWithSaverLoader struct { + faker +} + +func (f *fakerWithSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerWithSaverLoader) StateLoad(m state.Source) {} + +// fakerOther calls itself .. uh, itself? +type fakerOther string + +func (f *fakerOther) StateTypeName() string { + return string(*f) +} + +func (f *fakerOther) StateFields() []string { + return nil +} + +func newFakerOther(name string) *fakerOther { + f := fakerOther(name) + return &f +} + +// fakerOtherBadFields returns non-nil fields. +type fakerOtherBadFields string + +func (f *fakerOtherBadFields) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherBadFields) StateFields() []string { + return []string{string(*f)} +} + +func newFakerOtherBadFields(name string) *fakerOtherBadFields { + f := fakerOtherBadFields(name) + return &f +} + +// fakerOtherSaverLoader implements SaverLoader methods. +type fakerOtherSaverLoader string + +func (f *fakerOtherSaverLoader) StateTypeName() string { + return string(*f) +} + +func (f *fakerOtherSaverLoader) StateFields() []string { + return nil +} + +func (f *fakerOtherSaverLoader) StateSave(m state.Sink) {} + +func (f *fakerOtherSaverLoader) StateLoad(m state.Source) {} + +func newFakerOtherSaverLoader(name string) *fakerOtherSaverLoader { + f := fakerOtherSaverLoader(name) + return &f +} + +func TestRegisterPrimitives(t *testing.T) { + for _, typeName := range []string{ + "int", + "int8", + "int16", + "int32", + "int64", + "uint", + "uintptr", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + "string", + } { + t.Run("struct/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(&faker{ + Name: typeName, + }) + }) + t.Run("other/"+typeName, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering type %q did not panic", typeName) + } + }() + state.Register(newFakerOther(typeName)) + }) + } +} + +func TestRegisterBad(t *testing.T) { + const ( + goodName = "foo" + firstField = "a" + secondField = "b" + ) + for name, object := range map[string]state.Type{ + "non-struct-with-fields": newFakerOtherBadFields(goodName), + "non-struct-with-saverloader": newFakerOtherSaverLoader(goodName), + "struct-without-saverloader": &faker{Name: goodName}, + "non-struct-duplicate-with-struct": newFakerOther((new(alreadyRegisteredStruct)).StateTypeName()), + "non-struct-duplicate-with-non-struct": newFakerOther((new(alreadyRegisteredOther)).StateTypeName()), + "struct-duplicate-with-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredStruct)).StateTypeName()}}, + "struct-duplicate-with-non-struct": &fakerWithSaverLoader{faker{Name: (new(alreadyRegisteredOther)).StateTypeName()}}, + "struct-with-empty-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{""}}}, + "struct-with-empty-field-and-non-empty": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, ""}}}, + "struct-with-duplicate-field": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, firstField}}}, + "struct-with-duplicate-field-and-non-dup": &fakerWithSaverLoader{faker{Name: goodName, Fields: []string{firstField, secondField, firstField}}}, + } { + t.Run(name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Registering object %#v did not panic", object) + } + }() + state.Register(object) + }) + + } +} diff --git a/pkg/state/tests/string_test.go b/pkg/state/tests/string_test.go new file mode 100644 index 000000000..44f5a562c --- /dev/null +++ b/pkg/state/tests/string_test.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" +) + +const nonEmptyString = "hello world" + +var allStrings = []string{ + "", + nonEmptyString, + "\\0", +} + +func TestString(t *testing.T) { + runTestCases(t, false, "plain", flatten(allStrings)) + runTestCases(t, false, "pointers", pointersTo(flatten(allStrings))) + runTestCases(t, false, "interfaces", interfacesTo(flatten(allStrings))) + runTestCases(t, false, "interfacesToPointers", interfacesTo(pointersTo(flatten(allStrings)))) +} diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go new file mode 100644 index 000000000..bd2c2b399 --- /dev/null +++ b/pkg/state/tests/struct.go @@ -0,0 +1,65 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +type unregisteredEmptyStruct struct{} + +// typeOnlyEmptyStruct just implements the state.Type interface. +type typeOnlyEmptyStruct struct{} + +func (*typeOnlyEmptyStruct) StateTypeName() string { return "registeredEmptyStruct" } + +func (*typeOnlyEmptyStruct) StateFields() []string { return nil } + +// +stateify savable +type savableEmptyStruct struct{} + +// +stateify savable +type emptyStructPointer struct { + nothing *struct{} +} + +// +stateify savable +type outerSame struct { + inner inner +} + +// +stateify savable +type outerFieldFirst struct { + inner inner + v int64 +} + +// +stateify savable +type outerFieldSecond struct { + v int64 + inner inner +} + +// +stateify savable +type outerArray struct { + inner [2]inner +} + +// +stateify savable +type inner struct { + v int64 +} + +// +stateify savable +type system struct { + v1 interface{} + v2 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go new file mode 100644 index 000000000..de9d17aa7 --- /dev/null +++ b/pkg/state/tests/struct_test.go @@ -0,0 +1,89 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tests + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/state" +) + +func TestEmptyStruct(t *testing.T) { + runTestCases(t, false, "plain", []interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + }) + runTestCases(t, false, "pointers", pointersTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + savableEmptyStruct{}, + })) + runTestCases(t, false, "interfaces-pass", interfacesTo([]interface{}{ + // Only registered types can be dispatched via interfaces. All + // other types should fail, even if it is the empty struct. + savableEmptyStruct{}, + })) + runTestCases(t, true, "interfaces-fail", interfacesTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + })) + runTestCases(t, false, "interfacesToPointers-pass", interfacesTo(pointersTo([]interface{}{ + savableEmptyStruct{}, + }))) + runTestCases(t, true, "interfacesToPointers-fail", interfacesTo(pointersTo([]interface{}{ + unregisteredEmptyStruct{}, + typeOnlyEmptyStruct{}, + }))) + + // Ensuring empty struct aliasing works. + es := emptyStructPointer{new(struct{})} + runTestCases(t, false, "empty-struct-pointers", []interface{}{ + emptyStructPointer{}, + es, + []emptyStructPointer{es, es}, // Same pointer. + }) +} + +func TestRegisterTypeOnlyStruct(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Register did not panic") + } + }() + state.Register((*typeOnlyEmptyStruct)(nil)) +} + +func TestEmbeddedPointers(t *testing.T) { + var ( + ofs outerSame + of1 outerFieldFirst + of2 outerFieldSecond + oa outerArray + ) + + runTestCases(t, false, "embedded-pointers", []interface{}{ + system{&ofs, &ofs.inner}, + system{&ofs.inner, &ofs}, + system{&of1, &of1.inner}, + system{&of1.inner, &of1}, + system{&of2, &of2.inner}, + system{&of2.inner, &of2}, + system{&oa, &oa.inner[0]}, + system{&oa, &oa.inner[1]}, + system{&oa.inner[0], &oa}, + system{&oa.inner[1], &oa}, + }) +} diff --git a/pkg/state/tests/tests.go b/pkg/state/tests/tests.go new file mode 100644 index 000000000..435a0e9db --- /dev/null +++ b/pkg/state/tests/tests.go @@ -0,0 +1,215 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tests tests the state packages. +package tests + +import ( + "bytes" + "context" + "fmt" + "math" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/pretty" +) + +// discard is an implementation of wire.Writer. +type discard struct{} + +// Write implements wire.Writer.Write. +func (discard) Write(p []byte) (int, error) { return len(p), nil } + +// WriteByte implements wire.Writer.WriteByte. +func (discard) WriteByte(byte) error { return nil } + +// checkEqual checks if two objects are equal. +// +// N.B. This only handles one level of dereferences for NaN. Otherwise we +// would need to fork the entire implementation of reflect.DeepEqual. +func checkEqual(root, loadedValue interface{}) bool { + if reflect.DeepEqual(root, loadedValue) { + return true + } + + // NaN is not equal to itself. We handle the case of raw floating point + // primitives here, but don't handle this case nested. + rf32, ok1 := root.(float32) + lf32, ok2 := loadedValue.(float32) + if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) { + return true + } + rf64, ok1 := root.(float64) + lf64, ok2 := loadedValue.(float64) + if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) { + return true + } + + // Same real for complex numbers. + rc64, ok1 := root.(complex64) + lc64, ok2 := root.(complex64) + if ok1 && ok2 { + return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64)) + } + rc128, ok1 := root.(complex128) + lc128, ok2 := root.(complex128) + if ok1 && ok2 { + return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128)) + } + + return false +} + +// runTestCases runs a test for each object in objects. +func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) { + t.Helper() + for i, root := range objects { + t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) { + t.Logf("Original object:\n%#v", root) + + // Save the passed object. + saveBuffer := &bytes.Buffer{} + saveObjectPtr := reflect.New(reflect.TypeOf(root)) + saveObjectPtr.Elem().Set(reflect.ValueOf(root)) + saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface()) + if err != nil { + if shouldFail { + return + } + t.Fatalf("Save failed unexpectedly: %v", err) + } + + // Dump the serialized proto to aid with debugging. + var ppBuf bytes.Buffer + t.Logf("Raw state:\n%v", saveBuffer.Bytes()) + if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil { + // We don't count this as a test failure if we + // have shouldFail set, but we will count as a + // failure if we were not expecting to fail. + if !shouldFail { + t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err) + } + } + if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil { + // See above. + if !shouldFail { + t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err) + } + } + t.Logf("Encoded state:\n%s", ppBuf.String()) + t.Logf("Save stats:\n%s", saveStats.String()) + + // Load a new copy of the object. + loadObjectPtr := reflect.New(reflect.TypeOf(root)) + loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface()) + if err != nil { + if shouldFail { + return + } + t.Fatalf("Load failed unexpectedly: %v", err) + } + + // Compare the values. + loadedValue := loadObjectPtr.Elem().Interface() + if !checkEqual(root, loadedValue) { + if shouldFail { + return + } + t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue) + } + + // Everything went okay. Is that good? + if shouldFail { + t.Fatalf("This test was expected to fail, but didn't.") + } + t.Logf("Load stats:\n%s", loadStats.String()) + + // Truncate half the bytes in the byte stream, + // and ensure that we can't restore. Then + // truncate only the final byte and ensure that + // we can't restore. + l := saveBuffer.Len() + halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2]) + if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil { + t.Errorf("Load with half bytes succeeded unexpectedly.") + } + missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1]) + if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil { + t.Errorf("Load with missing byte succeeded unexpectedly.") + } + }) + } +} + +// convert converts the slice to an []interface{}. +func convert(v interface{}) (r []interface{}) { + s := reflect.ValueOf(v) // Must be slice. + for i := 0; i < s.Len(); i++ { + r = append(r, s.Index(i).Interface()) + } + return r +} + +// flatten flattens multiple slices. +func flatten(vs ...interface{}) (r []interface{}) { + for _, v := range vs { + r = append(r, convert(v)...) + } + return r +} + +// filter maps from one slice to another. +func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) { + s := reflect.ValueOf(vs) + for i := 0; i < s.Len(); i++ { + v, ok := fn(s.Index(i).Interface()) + if ok { + r = append(r, v) + } + } + return r +} + +// combine combines objects in two slices as specified. +func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) { + s1 := reflect.ValueOf(v1) + s2 := reflect.ValueOf(v2) + for i := 0; i < s1.Len(); i++ { + for j := 0; j < s2.Len(); j++ { + // Combine using the given function. + r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface())) + } + } + return r +} + +// pointersTo is a filter function that returns pointers. +func pointersTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + v := reflect.New(reflect.TypeOf(o)) + v.Elem().Set(reflect.ValueOf(o)) + return v.Interface(), true + }) +} + +// interfacesTo is a filter function that returns interface objects. +func interfacesTo(vs interface{}) []interface{} { + return filter(vs, func(o interface{}) (interface{}, bool) { + var v [1]interface{} + v[0] = o + return v, true + }) +} diff --git a/pkg/state/types.go b/pkg/state/types.go new file mode 100644 index 000000000..215ef80f8 --- /dev/null +++ b/pkg/state/types.go @@ -0,0 +1,361 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "reflect" + "sort" + + "gvisor.dev/gvisor/pkg/state/wire" +) + +// assertValidType asserts that the type is valid. +func assertValidType(name string, fields []string) { + if name == "" { + Failf("type has empty name") + } + fieldsCopy := make([]string, len(fields)) + for i := 0; i < len(fields); i++ { + if fields[i] == "" { + Failf("field has empty name for type %q", name) + } + fieldsCopy[i] = fields[i] + } + sort.Slice(fieldsCopy, func(i, j int) bool { + return fieldsCopy[i] < fieldsCopy[j] + }) + for i := range fieldsCopy { + if i > 0 && fieldsCopy[i-1] == fieldsCopy[i] { + Failf("duplicate field %q for type %s", fieldsCopy[i], name) + } + } +} + +// typeEntry is an entry in the typeDatabase. +type typeEntry struct { + ID typeID + wire.Type +} + +// reconciledTypeEntry is a reconciled entry in the typeDatabase. +type reconciledTypeEntry struct { + wire.Type + LocalType reflect.Type + FieldOrder []int +} + +// typeEncodeDatabase is an internal TypeInfo database for encoding. +type typeEncodeDatabase struct { + // byType maps by type to the typeEntry. + byType map[reflect.Type]*typeEntry + + // lastID is the last used ID. + lastID typeID +} + +// makeTypeEncodeDatabase makes a typeDatabase. +func makeTypeEncodeDatabase() typeEncodeDatabase { + return typeEncodeDatabase{ + byType: make(map[reflect.Type]*typeEntry), + } +} + +// typeDecodeDatabase is an internal TypeInfo database for decoding. +type typeDecodeDatabase struct { + // byID maps by ID to type. + byID []*reconciledTypeEntry + + // pending are entries that are pending validation by Lookup. These + // will be reconciled with actual objects. Note that these will also be + // used to lookup types by name, since they may not be reconciled and + // there's little value to deleting from this map. + pending []*wire.Type +} + +// makeTypeDecodeDatabase makes a typeDatabase. +func makeTypeDecodeDatabase() typeDecodeDatabase { + return typeDecodeDatabase{} +} + +// lookupNameFields extracts the name and fields from an object. +func lookupNameFields(typ reflect.Type) (string, []string, bool) { + v := reflect.Zero(reflect.PtrTo(typ)).Interface() + t, ok := v.(Type) + if !ok { + // Is this a primitive? + if typ.Kind() == reflect.Interface { + return interfaceType, nil, true + } + name := typ.Name() + if _, ok := primitiveTypeDatabase[name]; !ok { + // This is not a known type, and not a primitive. The + // encoder may proceed for anonymous empty structs, or + // it may deference the type pointer and try again. + return "", nil, false + } + return name, nil, true + } + // Extract the name from the object. + name := t.StateTypeName() + fields := t.StateFields() + assertValidType(name, fields) + return name, fields, true +} + +// Lookup looks up or registers the given object. +// +// The bool indicates whether this is an existing entry: false means the entry +// did not exist, and true means the entry did exist. If this bool is false and +// the returned typeEntry are nil, then the obj did not implement the Type +// interface. +func (tdb *typeEncodeDatabase) Lookup(typ reflect.Type) (*typeEntry, bool) { + te, ok := tdb.byType[typ] + if !ok { + // Lookup the type information. + name, fields, ok := lookupNameFields(typ) + if !ok { + // Empty structs may still be encoded, so let the + // caller decide what to do from here. + return nil, false + } + + // Register the new type. + tdb.lastID++ + te = &typeEntry{ + ID: tdb.lastID, + Type: wire.Type{ + Name: name, + Fields: fields, + }, + } + + // All done. + tdb.byType[typ] = te + return te, false + } + return te, true +} + +// Register adds a typeID entry. +func (tbd *typeDecodeDatabase) Register(typ *wire.Type) { + assertValidType(typ.Name, typ.Fields) + tbd.pending = append(tbd.pending, typ) +} + +// LookupName looks up the type name by ID. +func (tbd *typeDecodeDatabase) LookupName(id typeID) string { + if len(tbd.pending) < int(id) { + // This is likely an encoder error? + Failf("type ID %d not available", id) + } + return tbd.pending[id-1].Name +} + +// LookupType looks up the type by ID. +func (tbd *typeDecodeDatabase) LookupType(id typeID) reflect.Type { + name := tbd.LookupName(id) + typ, ok := globalTypeDatabase[name] + if !ok { + // If not available, see if it's primitive. + typ, ok = primitiveTypeDatabase[name] + if !ok && name == interfaceType { + // Matches the built-in interface type. + var i interface{} + return reflect.TypeOf(&i).Elem() + } + if !ok { + // The type is perhaps not registered? + Failf("type name %q is not available", name) + } + return typ // Primitive type. + } + return typ // Registered type. +} + +// singleFieldOrder defines the field order for a single field. +var singleFieldOrder = []int{0} + +// Lookup looks up or registers the given object. +// +// First, the typeID is searched to see if this has already been appropriately +// reconciled. If no, then a reconcilation will take place that may result in a +// field ordering. If a nil reconciledTypeEntry is returned from this method, +// then the object does not support the Type interface. +// +// This method never returns nil. +func (tbd *typeDecodeDatabase) Lookup(id typeID, typ reflect.Type) *reconciledTypeEntry { + if len(tbd.byID) > int(id) && tbd.byID[id-1] != nil { + // Already reconciled. + return tbd.byID[id-1] + } + // The ID has not been reconciled yet. That's fine. We need to make + // sure it aligns with the current provided object. + if len(tbd.pending) < int(id) { + // This id was never registered. Probably an encoder error? + Failf("typeDatabase does not contain id %d", id) + } + // Extract the pending info. + pending := tbd.pending[id-1] + // Grow the byID list. + if len(tbd.byID) < int(id) { + tbd.byID = append(tbd.byID, make([]*reconciledTypeEntry, int(id)-len(tbd.byID))...) + } + // Reconcile the type. + name, fields, ok := lookupNameFields(typ) + if !ok { + // Empty structs are decoded only when the type is nil. Since + // this isn't the case, we fail here. + Failf("unsupported type %q during decode; can't reconcile", pending.Name) + } + if name != pending.Name { + // Are these the same type? Print a helpful message as this may + // actually happen in practice if types change. + Failf("typeDatabase contains conflicting definitions for id %d: %s->%v (current) and %s->%v (existing)", + id, name, fields, pending.Name, pending.Fields) + } + rte := &reconciledTypeEntry{ + Type: wire.Type{ + Name: name, + Fields: fields, + }, + LocalType: typ, + } + // If there are zero or one fields, then we skip allocating the field + // slice. There is special handling for decoding in this case. If the + // field name does not match, it will be caught in the general purpose + // code below. + if len(fields) != len(pending.Fields) { + Failf("type %q contains different fields: %v (decode) and %v (encode)", + name, fields, pending.Fields) + } + if len(fields) == 0 { + tbd.byID[id-1] = rte // Save. + return rte + } + if len(fields) == 1 && fields[0] == pending.Fields[0] { + tbd.byID[id-1] = rte // Save. + rte.FieldOrder = singleFieldOrder + return rte + } + // For each field in the current object's information, match it to a + // field in the destination object. We know from the assertion above + // and the insertion on insertion to pending that neither field + // contains any duplicates. + fieldOrder := make([]int, len(fields)) + for i, name := range fields { + fieldOrder[i] = -1 // Sentinel. + // Is it an exact match? + if pending.Fields[i] == name { + fieldOrder[i] = i + continue + } + // Find the matching field. + for j, otherName := range pending.Fields { + if name == otherName { + fieldOrder[i] = j + break + } + } + if fieldOrder[i] == -1 { + // The type name matches but we are lacking some common fields. + Failf("type %q has mismatched fields: %v (decode) and %v (encode)", + name, fields, pending.Fields) + } + } + // The type has been reeconciled. + rte.FieldOrder = fieldOrder + tbd.byID[id-1] = rte + return rte +} + +// interfaceType defines all interfaces. +const interfaceType = "interface" + +// primitiveTypeDatabase is a set of fixed types. +var primitiveTypeDatabase = func() map[string]reflect.Type { + r := make(map[string]reflect.Type) + for _, t := range []reflect.Type{ + reflect.TypeOf(false), + reflect.TypeOf(int(0)), + reflect.TypeOf(int8(0)), + reflect.TypeOf(int16(0)), + reflect.TypeOf(int32(0)), + reflect.TypeOf(int64(0)), + reflect.TypeOf(uint(0)), + reflect.TypeOf(uintptr(0)), + reflect.TypeOf(uint8(0)), + reflect.TypeOf(uint16(0)), + reflect.TypeOf(uint32(0)), + reflect.TypeOf(uint64(0)), + reflect.TypeOf(""), + reflect.TypeOf(float32(0.0)), + reflect.TypeOf(float64(0.0)), + reflect.TypeOf(complex64(0.0)), + reflect.TypeOf(complex128(0.0)), + } { + r[t.Name()] = t + } + return r +}() + +// globalTypeDatabase is used for dispatching interfaces on decode. +var globalTypeDatabase = map[string]reflect.Type{} + +// Register registers a type. +// +// This must be called on init and only done once. +func Register(t Type) { + name := t.StateTypeName() + fields := t.StateFields() + assertValidType(name, fields) + // Register must always be called on pointers. + typ := reflect.TypeOf(t) + if typ.Kind() != reflect.Ptr { + Failf("Register must be called on pointers") + } + typ = typ.Elem() + if typ.Kind() == reflect.Struct { + // All registered structs must implement SaverLoader. We allow + // the registration is non-struct types with just the Type + // interface, but we need to call StateSave/StateLoad methods + // on aggregate types. + if _, ok := t.(SaverLoader); !ok { + Failf("struct %T does not implement SaverLoader", t) + } + } else { + // Non-structs must not have any fields. We don't support + // calling StateSave/StateLoad methods on any non-struct types. + // If custom behavior is required, these types should be + // wrapped in a structure of some kind. + if len(fields) != 0 { + Failf("non-struct %T has non-zero fields %v", t, fields) + } + // We don't allow non-structs to implement StateSave/StateLoad + // methods, because they won't be called and it's confusing. + if _, ok := t.(SaverLoader); ok { + Failf("non-struct %T implements SaverLoader", t) + } + } + if _, ok := primitiveTypeDatabase[name]; ok { + Failf("conflicting primitiveTypeDatabase entry for %T: used by primitive", t) + } + if _, ok := globalTypeDatabase[name]; ok { + Failf("conflicting globalTypeDatabase entries for %T: name conflict", t) + } + if name == interfaceType { + Failf("conflicting name for %T: matches interfaceType", t) + } + globalTypeDatabase[name] = typ +} diff --git a/pkg/state/wire/BUILD b/pkg/state/wire/BUILD new file mode 100644 index 000000000..311b93dcb --- /dev/null +++ b/pkg/state/wire/BUILD @@ -0,0 +1,12 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "wire", + srcs = ["wire.go"], + marshal = False, + stateify = False, + visibility = ["//:sandbox"], + deps = ["//pkg/gohacks"], +) diff --git a/pkg/state/wire/wire.go b/pkg/state/wire/wire.go new file mode 100644 index 000000000..93dee6740 --- /dev/null +++ b/pkg/state/wire/wire.go @@ -0,0 +1,970 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package wire contains a few basic types that can be composed to serialize +// graph information for the state package. This package defines the wire +// protocol. +// +// Note that these types are careful about how they implement the relevant +// interfaces (either value receiver or pointer receiver), so that native-sized +// types, such as integers and simple pointers, can fit inside the interface +// object. +// +// This package also uses panic as control flow, so called should be careful to +// wrap calls in appropriate handlers. +// +// Testing for this package is driven by the state test package. +package wire + +import ( + "fmt" + "io" + "math" + + "gvisor.dev/gvisor/pkg/gohacks" +) + +// Reader is the required reader interface. +type Reader interface { + io.Reader + ReadByte() (byte, error) +} + +// Writer is the required writer interface. +type Writer interface { + io.Writer + WriteByte(byte) error +} + +// readFull is a utility. The equivalent is not needed for Write, but the API +// contract dictates that it must always complete all bytes given or return an +// error. +func readFull(r io.Reader, p []byte) { + for done := 0; done < len(p); { + n, err := r.Read(p[done:]) + done += n + if n == 0 && err != nil { + panic(err) + } + } +} + +// Object is a generic object. +type Object interface { + // save saves the given object. + // + // Panic is used for error control flow. + save(Writer) + + // load loads a new object of the given type. + // + // Panic is used for error control flow. + load(Reader) Object +} + +// Bool is a boolean. +type Bool bool + +// loadBool loads an object of type Bool. +func loadBool(r Reader) Bool { + b := loadUint(r) + return Bool(b == 1) +} + +// save implements Object.save. +func (b Bool) save(w Writer) { + var v Uint + if b { + v = 1 + } else { + v = 0 + } + v.save(w) +} + +// load implements Object.load. +func (Bool) load(r Reader) Object { return loadBool(r) } + +// Int is a signed integer. +// +// This uses varint encoding. +type Int int64 + +// loadInt loads an object of type Int. +func loadInt(r Reader) Int { + u := loadUint(r) + x := Int(u >> 1) + if u&1 != 0 { + x = ^x + } + return x +} + +// save implements Object.save. +func (i Int) save(w Writer) { + u := Uint(i) << 1 + if i < 0 { + u = ^u + } + u.save(w) +} + +// load implements Object.load. +func (Int) load(r Reader) Object { return loadInt(r) } + +// Uint is an unsigned integer. +type Uint uint64 + +// loadUint loads an object of type Uint. +func loadUint(r Reader) Uint { + var ( + u Uint + s uint + ) + for i := 0; i <= 9; i++ { + b, err := r.ReadByte() + if err != nil { + panic(err) + } + if b < 0x80 { + if i == 9 && b > 1 { + panic("overflow") + } + u |= Uint(b) << s + return u + } + u |= Uint(b&0x7f) << s + s += 7 + } + panic("unreachable") +} + +// save implements Object.save. +func (u Uint) save(w Writer) { + for u >= 0x80 { + if err := w.WriteByte(byte(u) | 0x80); err != nil { + panic(err) + } + u >>= 7 + } + if err := w.WriteByte(byte(u)); err != nil { + panic(err) + } +} + +// load implements Object.load. +func (Uint) load(r Reader) Object { return loadUint(r) } + +// Float32 is a 32-bit floating point number. +type Float32 float32 + +// loadFloat32 loads an object of type Float32. +func loadFloat32(r Reader) Float32 { + n := loadUint(r) + return Float32(math.Float32frombits(uint32(n))) +} + +// save implements Object.save. +func (f Float32) save(w Writer) { + n := Uint(math.Float32bits(float32(f))) + n.save(w) +} + +// load implements Object.load. +func (Float32) load(r Reader) Object { return loadFloat32(r) } + +// Float64 is a 64-bit floating point number. +type Float64 float64 + +// loadFloat64 loads an object of type Float64. +func loadFloat64(r Reader) Float64 { + n := loadUint(r) + return Float64(math.Float64frombits(uint64(n))) +} + +// save implements Object.save. +func (f Float64) save(w Writer) { + n := Uint(math.Float64bits(float64(f))) + n.save(w) +} + +// load implements Object.load. +func (Float64) load(r Reader) Object { return loadFloat64(r) } + +// Complex64 is a 64-bit complex number. +type Complex64 complex128 + +// loadComplex64 loads an object of type Complex64. +func loadComplex64(r Reader) Complex64 { + re := loadFloat32(r) + im := loadFloat32(r) + return Complex64(complex(float32(re), float32(im))) +} + +// save implements Object.save. +func (c *Complex64) save(w Writer) { + re := Float32(real(*c)) + im := Float32(imag(*c)) + re.save(w) + im.save(w) +} + +// load implements Object.load. +func (*Complex64) load(r Reader) Object { + c := loadComplex64(r) + return &c +} + +// Complex128 is a 128-bit complex number. +type Complex128 complex128 + +// loadComplex128 loads an object of type Complex128. +func loadComplex128(r Reader) Complex128 { + re := loadFloat64(r) + im := loadFloat64(r) + return Complex128(complex(float64(re), float64(im))) +} + +// save implements Object.save. +func (c *Complex128) save(w Writer) { + re := Float64(real(*c)) + im := Float64(imag(*c)) + re.save(w) + im.save(w) +} + +// load implements Object.load. +func (*Complex128) load(r Reader) Object { + c := loadComplex128(r) + return &c +} + +// String is a string. +type String string + +// loadString loads an object of type String. +func loadString(r Reader) String { + l := loadUint(r) + p := make([]byte, l) + readFull(r, p) + return String(gohacks.StringFromImmutableBytes(p)) +} + +// save implements Object.save. +func (s *String) save(w Writer) { + l := Uint(len(*s)) + l.save(w) + p := gohacks.ImmutableBytesFromString(string(*s)) + _, err := w.Write(p) // Must write all bytes. + if err != nil { + panic(err) + } +} + +// load implements Object.load. +func (*String) load(r Reader) Object { + s := loadString(r) + return &s +} + +// Dot is a kind of reference: one of Index and FieldName. +type Dot interface { + isDot() +} + +// Index is a reference resolution. +type Index uint32 + +func (Index) isDot() {} + +// FieldName is a reference resolution. +type FieldName string + +func (*FieldName) isDot() {} + +// Ref is a reference to an object. +type Ref struct { + // Root is the root object. + Root Uint + + // Dots is the set of traversals required from the Root object above. + // Note that this will be stored in reverse order for efficiency. + Dots []Dot + + // Type is the base type for the root object. This is non-nil iff Dots + // is non-zero length (that is, this is a complex reference). This is + // not *strictly* necessary, but can be used to simplify decoding. + Type TypeSpec +} + +// loadRef loads an object of type Ref (abstract). +func loadRef(r Reader) Ref { + ref := Ref{ + Root: loadUint(r), + } + l := loadUint(r) + ref.Dots = make([]Dot, l) + for i := 0; i < int(l); i++ { + // Disambiguate between an Index (non-negative) and a field + // name (negative). This does some space and avoids a dedicate + // loadDot function. See Ref.save for the other side. + d := loadInt(r) + if d >= 0 { + ref.Dots[i] = Index(d) + continue + } + p := make([]byte, -d) + readFull(r, p) + fieldName := FieldName(gohacks.StringFromImmutableBytes(p)) + ref.Dots[i] = &fieldName + } + if l != 0 { + // Only if dots is non-zero. + ref.Type = loadTypeSpec(r) + } + return ref +} + +// save implements Object.save. +func (r *Ref) save(w Writer) { + r.Root.save(w) + l := Uint(len(r.Dots)) + l.save(w) + for _, d := range r.Dots { + // See LoadRef. We use non-negative numbers to encode Index + // objects and negative numbers to encode field lengths. + switch x := d.(type) { + case Index: + i := Int(x) + i.save(w) + case *FieldName: + d := Int(-len(*x)) + d.save(w) + p := gohacks.ImmutableBytesFromString(string(*x)) + if _, err := w.Write(p); err != nil { + panic(err) + } + default: + panic("unknown dot implementation") + } + } + if l != 0 { + // See above. + saveTypeSpec(w, r.Type) + } +} + +// load implements Object.load. +func (*Ref) load(r Reader) Object { + ref := loadRef(r) + return &ref +} + +// Nil is a primitive zero value of any type. +type Nil struct{} + +// loadNil loads an object of type Nil. +func loadNil(r Reader) Nil { + return Nil{} +} + +// save implements Object.save. +func (Nil) save(w Writer) {} + +// load implements Object.load. +func (Nil) load(r Reader) Object { return loadNil(r) } + +// Slice is a slice value. +type Slice struct { + Length Uint + Capacity Uint + Ref Ref +} + +// loadSlice loads an object of type Slice. +func loadSlice(r Reader) Slice { + return Slice{ + Length: loadUint(r), + Capacity: loadUint(r), + Ref: loadRef(r), + } +} + +// save implements Object.save. +func (s *Slice) save(w Writer) { + s.Length.save(w) + s.Capacity.save(w) + s.Ref.save(w) +} + +// load implements Object.load. +func (*Slice) load(r Reader) Object { + s := loadSlice(r) + return &s +} + +// Array is an array value. +type Array struct { + Contents []Object +} + +// loadArray loads an object of type Array. +func loadArray(r Reader) Array { + l := loadUint(r) + if l == 0 { + // Note that there isn't a single object available to encode + // the type of, so we need this additional branch. + return Array{} + } + // All the objects here have the same type, so use dynamic dispatch + // only once. All other objects will automatically take the same type + // as the first object. + contents := make([]Object, l) + v := Load(r) + contents[0] = v + for i := 1; i < int(l); i++ { + contents[i] = v.load(r) + } + return Array{ + Contents: contents, + } +} + +// save implements Object.save. +func (a *Array) save(w Writer) { + l := Uint(len(a.Contents)) + l.save(w) + if l == 0 { + // See LoadArray. + return + } + // See above. + Save(w, a.Contents[0]) + for i := 1; i < int(l); i++ { + a.Contents[i].save(w) + } +} + +// load implements Object.load. +func (*Array) load(r Reader) Object { + a := loadArray(r) + return &a +} + +// Map is a map value. +type Map struct { + Keys []Object + Values []Object +} + +// loadMap loads an object of type Map. +func loadMap(r Reader) Map { + l := loadUint(r) + if l == 0 { + // See LoadArray. + return Map{} + } + // See type dispatch notes in Array. + keys := make([]Object, l) + values := make([]Object, l) + k := Load(r) + v := Load(r) + keys[0] = k + values[0] = v + for i := 1; i < int(l); i++ { + keys[i] = k.load(r) + values[i] = v.load(r) + } + return Map{ + Keys: keys, + Values: values, + } +} + +// save implements Object.save. +func (m *Map) save(w Writer) { + l := Uint(len(m.Keys)) + if int(l) != len(m.Values) { + panic(fmt.Sprintf("mismatched keys (%d) Aand values (%d)", len(m.Keys), len(m.Values))) + } + l.save(w) + if l == 0 { + // See LoadArray. + return + } + // See above. + Save(w, m.Keys[0]) + Save(w, m.Values[0]) + for i := 1; i < int(l); i++ { + m.Keys[i].save(w) + m.Values[i].save(w) + } +} + +// load implements Object.load. +func (*Map) load(r Reader) Object { + m := loadMap(r) + return &m +} + +// TypeSpec is a type dereference. +type TypeSpec interface { + isTypeSpec() +} + +// TypeID is a concrete type ID. +type TypeID Uint + +func (TypeID) isTypeSpec() {} + +// TypeSpecPointer is a pointer type. +type TypeSpecPointer struct { + Type TypeSpec +} + +func (*TypeSpecPointer) isTypeSpec() {} + +// TypeSpecArray is an array type. +type TypeSpecArray struct { + Count Uint + Type TypeSpec +} + +func (*TypeSpecArray) isTypeSpec() {} + +// TypeSpecSlice is a slice type. +type TypeSpecSlice struct { + Type TypeSpec +} + +func (*TypeSpecSlice) isTypeSpec() {} + +// TypeSpecMap is a map type. +type TypeSpecMap struct { + Key TypeSpec + Value TypeSpec +} + +func (*TypeSpecMap) isTypeSpec() {} + +// TypeSpecNil is an empty type. +type TypeSpecNil struct{} + +func (TypeSpecNil) isTypeSpec() {} + +// TypeSpec types. +// +// These use a distinct encoding on the wire, as they are used only in the +// interface object. They are decoded through the dedicated loadTypeSpec and +// saveTypeSpec functions. +const ( + typeSpecTypeID Uint = iota + typeSpecPointer + typeSpecArray + typeSpecSlice + typeSpecMap + typeSpecNil +) + +// loadTypeSpec loads TypeSpec values. +func loadTypeSpec(r Reader) TypeSpec { + switch hdr := loadUint(r); hdr { + case typeSpecTypeID: + return TypeID(loadUint(r)) + case typeSpecPointer: + return &TypeSpecPointer{ + Type: loadTypeSpec(r), + } + case typeSpecArray: + return &TypeSpecArray{ + Count: loadUint(r), + Type: loadTypeSpec(r), + } + case typeSpecSlice: + return &TypeSpecSlice{ + Type: loadTypeSpec(r), + } + case typeSpecMap: + return &TypeSpecMap{ + Key: loadTypeSpec(r), + Value: loadTypeSpec(r), + } + case typeSpecNil: + return TypeSpecNil{} + default: + // This is not a valid stream? + panic(fmt.Errorf("unknown header: %d", hdr)) + } +} + +// saveTypeSpec saves TypeSpec values. +func saveTypeSpec(w Writer, t TypeSpec) { + switch x := t.(type) { + case TypeID: + typeSpecTypeID.save(w) + Uint(x).save(w) + case *TypeSpecPointer: + typeSpecPointer.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecArray: + typeSpecArray.save(w) + x.Count.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecSlice: + typeSpecSlice.save(w) + saveTypeSpec(w, x.Type) + case *TypeSpecMap: + typeSpecMap.save(w) + saveTypeSpec(w, x.Key) + saveTypeSpec(w, x.Value) + case TypeSpecNil: + typeSpecNil.save(w) + default: + // This should not happen? + panic(fmt.Errorf("unknown type %T", t)) + } +} + +// Interface is an interface value. +type Interface struct { + Type TypeSpec + Value Object +} + +// loadInterface loads an object of type Interface. +func loadInterface(r Reader) Interface { + return Interface{ + Type: loadTypeSpec(r), + Value: Load(r), + } +} + +// save implements Object.save. +func (i *Interface) save(w Writer) { + saveTypeSpec(w, i.Type) + Save(w, i.Value) +} + +// load implements Object.load. +func (*Interface) load(r Reader) Object { + i := loadInterface(r) + return &i +} + +// Type is type information. +type Type struct { + Name string + Fields []string +} + +// loadType loads an object of type Type. +func loadType(r Reader) Type { + name := string(loadString(r)) + l := loadUint(r) + fields := make([]string, l) + for i := 0; i < int(l); i++ { + fields[i] = string(loadString(r)) + } + return Type{ + Name: name, + Fields: fields, + } +} + +// save implements Object.save. +func (t *Type) save(w Writer) { + s := String(t.Name) + s.save(w) + l := Uint(len(t.Fields)) + l.save(w) + for i := 0; i < int(l); i++ { + s := String(t.Fields[i]) + s.save(w) + } +} + +// load implements Object.load. +func (*Type) load(r Reader) Object { + t := loadType(r) + return &t +} + +// multipleObjects is a special type for serializing multiple objects. +type multipleObjects []Object + +// loadMultipleObjects loads a series of objects. +func loadMultipleObjects(r Reader) multipleObjects { + l := loadUint(r) + m := make(multipleObjects, l) + for i := 0; i < int(l); i++ { + m[i] = Load(r) + } + return m +} + +// save implements Object.save. +func (m *multipleObjects) save(w Writer) { + l := Uint(len(*m)) + l.save(w) + for i := 0; i < int(l); i++ { + Save(w, (*m)[i]) + } +} + +// load implements Object.load. +func (*multipleObjects) load(r Reader) Object { + m := loadMultipleObjects(r) + return &m +} + +// noObjects represents no objects. +type noObjects struct{} + +// loadNoObjects loads a sentinel. +func loadNoObjects(r Reader) noObjects { return noObjects{} } + +// save implements Object.save. +func (noObjects) save(w Writer) {} + +// load implements Object.load. +func (noObjects) load(r Reader) Object { return loadNoObjects(r) } + +// Struct is a basic composite value. +type Struct struct { + TypeID TypeID + fields Object // Optionally noObjects or *multipleObjects. +} + +// Field returns a pointer to the given field slot. +// +// This must be called after Alloc. +func (s *Struct) Field(i int) *Object { + if fields, ok := s.fields.(*multipleObjects); ok { + return &((*fields)[i]) + } + if _, ok := s.fields.(noObjects); ok { + // Alloc may be optionally called; can't call twice. + panic("Field called inappropriately, wrong Alloc?") + } + return &s.fields +} + +// Alloc allocates the given number of fields. +// +// This must be called before Add and Save. +// +// Precondition: slots must be positive. +func (s *Struct) Alloc(slots int) { + switch { + case slots == 0: + s.fields = noObjects{} + case slots == 1: + // Leave it alone. + case slots > 1: + fields := make(multipleObjects, slots) + s.fields = &fields + default: + // Violates precondition. + panic(fmt.Sprintf("Alloc called with negative slots %d?", slots)) + } +} + +// Fields returns the number of fields. +func (s *Struct) Fields() int { + switch x := s.fields.(type) { + case *multipleObjects: + return len(*x) + case noObjects: + return 0 + default: + return 1 + } +} + +// loadStruct loads an object of type Struct. +func loadStruct(r Reader) Struct { + return Struct{ + TypeID: TypeID(loadUint(r)), + fields: Load(r), + } +} + +// save implements Object.save. +// +// Precondition: Alloc must have been called, and the fields all filled in +// appropriately. See Alloc and Add for more details. +func (s *Struct) save(w Writer) { + Uint(s.TypeID).save(w) + Save(w, s.fields) +} + +// load implements Object.load. +func (*Struct) load(r Reader) Object { + s := loadStruct(r) + return &s +} + +// Object types. +// +// N.B. Be careful about changing the order or introducing new elements in the +// middle here. This is part of the wire format and shouldn't change. +const ( + typeBool Uint = iota + typeInt + typeUint + typeFloat32 + typeFloat64 + typeNil + typeRef + typeString + typeSlice + typeArray + typeMap + typeStruct + typeNoObjects + typeMultipleObjects + typeInterface + typeComplex64 + typeComplex128 + typeType +) + +// Save saves the given object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Save(w Writer, obj Object) { + switch x := obj.(type) { + case Bool: + typeBool.save(w) + x.save(w) + case Int: + typeInt.save(w) + x.save(w) + case Uint: + typeUint.save(w) + x.save(w) + case Float32: + typeFloat32.save(w) + x.save(w) + case Float64: + typeFloat64.save(w) + x.save(w) + case Nil: + typeNil.save(w) + x.save(w) + case *Ref: + typeRef.save(w) + x.save(w) + case *String: + typeString.save(w) + x.save(w) + case *Slice: + typeSlice.save(w) + x.save(w) + case *Array: + typeArray.save(w) + x.save(w) + case *Map: + typeMap.save(w) + x.save(w) + case *Struct: + typeStruct.save(w) + x.save(w) + case noObjects: + typeNoObjects.save(w) + x.save(w) + case *multipleObjects: + typeMultipleObjects.save(w) + x.save(w) + case *Interface: + typeInterface.save(w) + x.save(w) + case *Type: + typeType.save(w) + x.save(w) + case *Complex64: + typeComplex64.save(w) + x.save(w) + case *Complex128: + typeComplex128.save(w) + x.save(w) + default: + panic(fmt.Errorf("unknown type: %#v", obj)) + } +} + +// Load loads a new object. +// +// +checkescape all +// +// N.B. This function will panic on error. +func Load(r Reader) Object { + switch hdr := loadUint(r); hdr { + case typeBool: + return loadBool(r) + case typeInt: + return loadInt(r) + case typeUint: + return loadUint(r) + case typeFloat32: + return loadFloat32(r) + case typeFloat64: + return loadFloat64(r) + case typeNil: + return loadNil(r) + case typeRef: + return ((*Ref)(nil)).load(r) // Escapes. + case typeString: + return ((*String)(nil)).load(r) // Escapes. + case typeSlice: + return ((*Slice)(nil)).load(r) // Escapes. + case typeArray: + return ((*Array)(nil)).load(r) // Escapes. + case typeMap: + return ((*Map)(nil)).load(r) // Escapes. + case typeStruct: + return ((*Struct)(nil)).load(r) // Escapes. + case typeNoObjects: // Special for struct. + return loadNoObjects(r) + case typeMultipleObjects: // Special for struct. + return ((*multipleObjects)(nil)).load(r) // Escapes. + case typeInterface: + return ((*Interface)(nil)).load(r) // Escapes. + case typeComplex64: + return ((*Complex64)(nil)).load(r) // Escapes. + case typeComplex128: + return ((*Complex128)(nil)).load(r) // Escapes. + case typeType: + return ((*Type)(nil)).load(r) // Escapes. + default: + // This is not a valid stream? + panic(fmt.Errorf("unknown header: %d", hdr)) + } +} + +// LoadUint loads a single unsigned integer. +// +// N.B. This function will panic on error. +func LoadUint(r Reader) uint64 { + return uint64(loadUint(r)) +} + +// SaveUint saves a single unsigned integer. +// +// N.B. This function will panic on error. +func SaveUint(w Writer, v uint64) { + Uint(v).save(w) +} diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index d0d77e19c..4d47207f7 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -33,6 +33,7 @@ go_library( "aliases.go", "memmove_unsafe.go", "mutex_unsafe.go", + "nocopy.go", "norace_unsafe.go", "race_unsafe.go", "rwmutex_unsafe.go", diff --git a/pkg/sync/nocopy.go b/pkg/sync/nocopy.go new file mode 100644 index 000000000..722b29501 --- /dev/null +++ b/pkg/sync/nocopy.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sync + +// NoCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type NoCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +func (*NoCopy) Unlock() {} diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 8ff922c69..5ae10939d 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -22,7 +22,7 @@ import ( // Mapping for tcpip.Error types. var ( ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL) + ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index c73072c42..798e07b01 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -61,6 +61,7 @@ var ( ENOMEM = error(syscall.ENOMEM) ENOSPC = error(syscall.ENOSPC) ENOSYS = error(syscall.ENOSYS) + ENOTCONN = error(syscall.ENOTCONN) ENOTDIR = error(syscall.ENOTDIR) ENOTEMPTY = error(syscall.ENOTEMPTY) ENOTSOCK = error(syscall.ENOTSOCK) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index c1745ba6a..ee264b726 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -320,6 +320,22 @@ func DstPort(port uint16) TransportChecker { } } +// NoChecksum creates a checker that checks if the checksum is zero. +func NoChecksum(noChecksum bool) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + udp, ok := h.(header.UDP) + if !ok { + return + } + + if b := udp.Checksum() == 0; b != noChecksum { + t.Errorf("bad checksum state, got %t, want %t", b, noChecksum) + } + } +} + // SeqNum creates a checker that checks the sequence number. func SeqNum(seq uint32) TransportChecker { return func(t *testing.T, h header.Transport) { diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD index 0cde694dc..d87797617 100644 --- a/pkg/tcpip/header/BUILD +++ b/pkg/tcpip/header/BUILD @@ -48,7 +48,7 @@ go_test( "//pkg/rand", "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) @@ -64,6 +64,6 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go index 718a4720a..83189676e 100644 --- a/pkg/tcpip/header/arp.go +++ b/pkg/tcpip/header/arp.go @@ -14,14 +14,33 @@ package header -import "gvisor.dev/gvisor/pkg/tcpip" +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/tcpip" +) const ( // ARPProtocolNumber is the ARP network protocol number. ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806 // ARPSize is the size of an IPv4-over-Ethernet ARP packet. - ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4 + ARPSize = 28 +) + +// ARPHardwareType is the hardware type for LinkEndpoint in an ARP header. +type ARPHardwareType uint16 + +// Typical ARP HardwareType values. Some of the constants have to be specific +// values as they are egressed on the wire in the HTYPE field of an ARP header. +const ( + ARPHardwareNone ARPHardwareType = 0 + // ARPHardwareEther specifically is the HTYPE for Ethernet as specified + // in the IANA list here: + // + // https://www.iana.org/assignments/arp-parameters/arp-parameters.xhtml#arp-parameters-2 + ARPHardwareEther ARPHardwareType = 1 + ARPHardwareLoopback ARPHardwareType = 2 ) // ARPOp is an ARP opcode. @@ -36,54 +55,64 @@ const ( // ARP is an ARP packet stored in a byte array as described in RFC 826. type ARP []byte -func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) } -func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) } -func (a ARP) hardwareAddressSize() int { return int(a[4]) } -func (a ARP) protocolAddressSize() int { return int(a[5]) } +const ( + hTypeOffset = 0 + protocolOffset = 2 + haAddressSizeOffset = 4 + protoAddressSizeOffset = 5 + opCodeOffset = 6 + senderHAAddressOffset = 8 + senderProtocolAddressOffset = senderHAAddressOffset + EthernetAddressSize + targetHAAddressOffset = senderProtocolAddressOffset + IPv4AddressSize + targetProtocolAddressOffset = targetHAAddressOffset + EthernetAddressSize +) + +func (a ARP) hardwareAddressType() ARPHardwareType { + return ARPHardwareType(binary.BigEndian.Uint16(a[hTypeOffset:])) +} + +func (a ARP) protocolAddressSpace() uint16 { return binary.BigEndian.Uint16(a[protocolOffset:]) } +func (a ARP) hardwareAddressSize() int { return int(a[haAddressSizeOffset]) } +func (a ARP) protocolAddressSize() int { return int(a[protoAddressSizeOffset]) } // Op is the ARP opcode. -func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) } +func (a ARP) Op() ARPOp { return ARPOp(binary.BigEndian.Uint16(a[opCodeOffset:])) } // SetOp sets the ARP opcode. func (a ARP) SetOp(op ARPOp) { - a[6] = uint8(op >> 8) - a[7] = uint8(op) + binary.BigEndian.PutUint16(a[opCodeOffset:], uint16(op)) } // SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet. func (a ARP) SetIPv4OverEthernet() { - a[0], a[1] = 0, 1 // htypeEthernet - a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber - a[4] = 6 // macSize - a[5] = uint8(IPv4AddressSize) + binary.BigEndian.PutUint16(a[hTypeOffset:], uint16(ARPHardwareEther)) + binary.BigEndian.PutUint16(a[protocolOffset:], uint16(IPv4ProtocolNumber)) + a[haAddressSizeOffset] = EthernetAddressSize + a[protoAddressSizeOffset] = uint8(IPv4AddressSize) } // HardwareAddressSender is the link address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressSender() []byte { - const s = 8 - return a[s : s+6] + return a[senderHAAddressOffset : senderHAAddressOffset+EthernetAddressSize] } // ProtocolAddressSender is the protocol address of the sender. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressSender() []byte { - const s = 8 + 6 - return a[s : s+4] + return a[senderProtocolAddressOffset : senderProtocolAddressOffset+IPv4AddressSize] } // HardwareAddressTarget is the link address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) HardwareAddressTarget() []byte { - const s = 8 + 6 + 4 - return a[s : s+6] + return a[targetHAAddressOffset : targetHAAddressOffset+EthernetAddressSize] } // ProtocolAddressTarget is the protocol address of the target. // It is a view on to the ARP packet so it can be used to set the value. func (a ARP) ProtocolAddressTarget() []byte { - const s = 8 + 6 + 4 + 6 - return a[s : s+4] + return a[targetProtocolAddressOffset : targetProtocolAddressOffset+IPv4AddressSize] } // IsValid reports whether this is an ARP packet for IPv4 over Ethernet. @@ -91,10 +120,8 @@ func (a ARP) IsValid() bool { if len(a) < ARPSize { return false } - const htypeEthernet = 1 - const macSize = 6 - return a.hardwareAddressSpace() == htypeEthernet && + return a.hardwareAddressType() == ARPHardwareEther && a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) && - a.hardwareAddressSize() == macSize && + a.hardwareAddressSize() == EthernetAddressSize && a.protocolAddressSize() == IPv4AddressSize } diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go index b1e92d2d7..eaface8cb 100644 --- a/pkg/tcpip/header/eth.go +++ b/pkg/tcpip/header/eth.go @@ -53,6 +53,10 @@ const ( // (all bits set to 0). unspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00") + // EthernetBroadcastAddress is an ethernet address that addresses every node + // on a local link. + EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff") + // unicastMulticastFlagMask is the mask of the least significant bit in // the first octet (in network byte order) of an ethernet address that // determines whether the ethernet address is a unicast or multicast. If diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 7908c5744..1a631b31a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -72,6 +72,7 @@ const ( // Values for ICMP code as defined in RFC 792. const ( ICMPv4TTLExceeded = 0 + ICMPv4HostUnreachable = 1 ICMPv4PortUnreachable = 3 ICMPv4FragmentationNeeded = 4 ) diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go index c7ee2de57..a13b4b809 100644 --- a/pkg/tcpip/header/icmpv6.go +++ b/pkg/tcpip/header/icmpv6.go @@ -110,9 +110,16 @@ const ( ICMPv6RedirectMsg ICMPv6Type = 137 ) -// Values for ICMP code as defined in RFC 4443. +// Values for ICMP destination unreachable code as defined in RFC 4443 section +// 3.1. const ( - ICMPv6PortUnreachable = 4 + ICMPv6NetworkUnreachable = 0 + ICMPv6Prohibited = 1 + ICMPv6BeyondScope = 2 + ICMPv6AddressUnreachable = 3 + ICMPv6PortUnreachable = 4 + ICMPv6Policy = 5 + ICMPv6RejectRoute = 6 ) // Type is the ICMP type field. diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD index b8b93e78e..39ca774ef 100644 --- a/pkg/tcpip/link/channel/BUILD +++ b/pkg/tcpip/link/channel/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go index 20b183da0..e12a5929b 100644 --- a/pkg/tcpip/link/channel/channel.go +++ b/pkg/tcpip/link/channel/channel.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -296,3 +297,12 @@ func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle { func (e *Endpoint) RemoveNotify(handle *NotificationHandle) { e.q.RemoveNotify(handle) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD index aa6db9aea..507b44abc 100644 --- a/pkg/tcpip/link/fdbased/BUILD +++ b/pkg/tcpip/link/fdbased/BUILD @@ -15,6 +15,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/binary", + "//pkg/iovec", "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index f34082e1a..c18bb91fb 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -45,6 +45,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/iovec" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -385,26 +386,35 @@ const ( _VIRTIO_NET_HDR_GSO_TCPV6 = 4 ) -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { if e.hdrSize > 0 { // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } // Preserve the src address if it's set in the route. - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) } +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if e.hdrSize > 0 { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) + } + + var builder iovec.Builder fd := e.fds[pkt.Hash%uint32(len(e.fds))] if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { @@ -430,47 +440,28 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne } vnetHdrBuf := binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - return rawfile.NonBlockingWrite3(fd, vnetHdrBuf, pkt.Header.View(), pkt.Data.ToView()) + builder.Add(vnetHdrBuf) } - if pkt.Data.Size() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Header.View()) - } - if pkt.Header.UsedLength() == 0 { - return rawfile.NonBlockingWrite(fd, pkt.Data.ToView()) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) } - return rawfile.NonBlockingWrite3(fd, pkt.Header.View(), pkt.Data.ToView(), nil) + return rawfile.NonBlockingWriteIovec(fd, builder.Build()) } func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tcpip.Error) { // Send a batch of packets through batchFD. mmsgHdrs := make([]rawfile.MMsgHdr, 0, len(batch)) for _, pkt := range batch { - var ethHdrBuf []byte - iovLen := 0 if e.hdrSize > 0 { - // Add ethernet header if needed. - ethHdrBuf = make([]byte, header.EthernetMinimumSize) - eth := header.Ethernet(ethHdrBuf) - ethHdr := &header.EthernetFields{ - DstAddr: pkt.EgressRoute.RemoteLinkAddress, - Type: pkt.NetworkProtocolNumber, - } - - // Preserve the src address if it's set in the route. - if pkt.EgressRoute.LocalLinkAddress != "" { - ethHdr.SrcAddr = pkt.EgressRoute.LocalLinkAddress - } else { - ethHdr.SrcAddr = e.addr - } - eth.Encode(ethHdr) - iovLen++ + e.AddHeader(pkt.EgressRoute.LocalLinkAddress, pkt.EgressRoute.RemoteLinkAddress, pkt.NetworkProtocolNumber, pkt) } - vnetHdr := virtioNetHdr{} var vnetHdrBuf []byte if e.Capabilities()&stack.CapabilityHardwareGSO != 0 { + vnetHdr := virtioNetHdr{} if pkt.GSOOptions != nil { vnetHdr.hdrLen = uint16(pkt.Header.UsedLength()) if pkt.GSOOptions.NeedsCsum { @@ -491,45 +482,19 @@ func (e *endpoint) sendBatch(batchFD int, batch []*stack.PacketBuffer) (int, *tc } } vnetHdrBuf = binary.Marshal(make([]byte, 0, virtioNetHdrSize), binary.LittleEndian, vnetHdr) - iovLen++ } - iovecs := make([]syscall.Iovec, iovLen+1+len(pkt.Data.Views())) + var builder iovec.Builder + builder.Add(vnetHdrBuf) + builder.Add(pkt.Header.View()) + for _, v := range pkt.Data.Views() { + builder.Add(v) + } + iovecs := builder.Build() + var mmsgHdr rawfile.MMsgHdr mmsgHdr.Msg.Iov = &iovecs[0] - iovecIdx := 0 - if vnetHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = &vnetHdrBuf[0] - v.Len = uint64(len(vnetHdrBuf)) - iovecIdx++ - } - if ethHdrBuf != nil { - v := &iovecs[iovecIdx] - v.Base = ðHdrBuf[0] - v.Len = uint64(len(ethHdrBuf)) - iovecIdx++ - } - pktSize := uint64(0) - // Encode L3 Header - v := &iovecs[iovecIdx] - hdr := &pkt.Header - hdrView := hdr.View() - v.Base = &hdrView[0] - v.Len = uint64(len(hdrView)) - pktSize += v.Len - iovecIdx++ - - // Now encode the Transport Payload. - pktViews := pkt.Data.Views() - for i := range pktViews { - vec := &iovecs[iovecIdx] - iovecIdx++ - vec.Base = &pktViews[i][0] - vec.Len = uint64(len(pktViews[i])) - pktSize += vec.Len - } - mmsgHdr.Msg.Iovlen = uint64(iovecIdx) + mmsgHdr.Msg.Iovlen = uint64(len(iovecs)) mmsgHdrs = append(mmsgHdrs, mmsgHdr) } @@ -626,6 +591,14 @@ func (e *endpoint) GSOMaxSize() uint32 { return e.gsoMaxSize } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + if e.hdrSize > 0 { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + // InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes // to the FD, but does not read from it. All reads come from injected packets. type InjectableEndpoint struct { diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index eaee7e5d7..7b995b85a 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -107,6 +107,10 @@ func (c *context) DeliverNetworkPacket(remote tcpip.LinkAddress, local tcpip.Lin c.ch <- packetInfo{remote, protocol, pkt} } +func (c *context) DeliverOutboundPacket(remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestNoEthernetProperties(t *testing.T) { c := newContext(t, &Options{MTU: mtu}) defer c.cleanup() @@ -500,3 +504,80 @@ func TestRecvMMsgDispatcherCapLength(t *testing.T) { } } + +// fakeNetworkDispatcher delivers packets to pkts. +type fakeNetworkDispatcher struct { + pkts []*stack.PacketBuffer +} + +func (d *fakeNetworkDispatcher) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + d.pkts = append(d.pkts, pkt) +} + +func (d *fakeNetworkDispatcher) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestDispatchPacketFormat(t *testing.T) { + for _, test := range []struct { + name string + newDispatcher func(fd int, e *endpoint) (linkDispatcher, error) + }{ + { + name: "readVDispatcher", + newDispatcher: newReadVDispatcher, + }, + { + name: "recvMMsgDispatcher", + newDispatcher: newRecvMMsgDispatcher, + }, + } { + t.Run(test.name, func(t *testing.T) { + // Create a socket pair to send/recv. + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Fatal(err) + } + defer syscall.Close(fds[0]) + defer syscall.Close(fds[1]) + + data := []byte{ + // Ethernet header. + 1, 2, 3, 4, 5, 60, + 1, 2, 3, 4, 5, 61, + 8, 0, + // Mock network header. + 40, 41, 42, 43, + } + err = syscall.Sendmsg(fds[1], data, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + + // Create and run dispatcher once. + sink := &fakeNetworkDispatcher{} + d, err := test.newDispatcher(fds[0], &endpoint{ + hdrSize: header.EthernetMinimumSize, + dispatcher: sink, + }) + if err != nil { + t.Fatal(err) + } + if ok, err := d.dispatch(); !ok || err != nil { + t.Fatalf("d.dispatch() = %v, %v", ok, err) + } + + // Verify packet. + if got, want := len(sink.pkts), 1; got != want { + t.Fatalf("len(sink.pkts) = %d, want %d", got, want) + } + pkt := sink.pkts[0] + if got, want := len(pkt.LinkHeader), header.EthernetMinimumSize; got != want { + t.Errorf("len(pkt.LinkHeader) = %d, want %d", got, want) + } + if got, want := pkt.Data.Size(), 4; got != want { + t.Errorf("pkt.Data.Size() = %d, want %d", got, want) + } + }) + } +} diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index f04738cfb..d8f2504b3 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -278,7 +278,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { eth header.Ethernet ) if d.e.hdrSize > 0 { - eth = header.Ethernet(d.views[k][0]) + eth = header.Ethernet(d.views[k][0][:header.EthernetMinimumSize]) p = eth.Type() remote = eth.SourceAddress() local = eth.DestinationAddress() diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go index 568c6874f..781cdd317 100644 --- a/pkg/tcpip/link/loopback/loopback.go +++ b/pkg/tcpip/link/loopback/loopback.go @@ -113,3 +113,11 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { return nil } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareLoopback +} + +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD index 82b441b79..e7493e5c5 100644 --- a/pkg/tcpip/link/muxed/BUILD +++ b/pkg/tcpip/link/muxed/BUILD @@ -9,6 +9,7 @@ go_library( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go index c69d6b7e9..56a611825 100644 --- a/pkg/tcpip/link/muxed/injectable.go +++ b/pkg/tcpip/link/muxed/injectable.go @@ -18,6 +18,7 @@ package muxed import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -129,6 +130,15 @@ func (m *InjectableEndpoint) Wait() { } } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*InjectableEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unsupported operation") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*InjectableEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { +} + // NewInjectableEndpoint creates a new multi-endpoint injectable endpoint. func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint { return &InjectableEndpoint{ diff --git a/pkg/tcpip/link/nested/BUILD b/pkg/tcpip/link/nested/BUILD new file mode 100644 index 000000000..2cdb23475 --- /dev/null +++ b/pkg/tcpip/link/nested/BUILD @@ -0,0 +1,32 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "nested", + srcs = [ + "nested.go", + ], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) + +go_test( + name = "nested_test", + size = "small", + srcs = [ + "nested_test.go", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/nested/nested.go b/pkg/tcpip/link/nested/nested.go new file mode 100644 index 000000000..d40de54df --- /dev/null +++ b/pkg/tcpip/link/nested/nested.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package nested provides helpers to implement the pattern of nested +// stack.LinkEndpoints. +package nested + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// Endpoint is a wrapper around stack.LinkEndpoint and stack.NetworkDispatcher +// that can be used to implement nesting safely by providing lifecycle +// concurrency guards. +// +// See the tests in this package for example usage. +type Endpoint struct { + child stack.LinkEndpoint + embedder stack.NetworkDispatcher + + // mu protects dispatcher. + mu sync.RWMutex + dispatcher stack.NetworkDispatcher +} + +var _ stack.GSOEndpoint = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) +var _ stack.NetworkDispatcher = (*Endpoint)(nil) + +// Init initializes a nested.Endpoint that uses embedder as the dispatcher for +// child on Attach. +// +// See the tests in this package for example usage. +func (e *Endpoint) Init(child stack.LinkEndpoint, embedder stack.NetworkDispatcher) { + e.child = child + e.embedder = embedder +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverNetworkPacket(remote, local, protocol, pkt) + } +} + +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.mu.RLock() + d := e.dispatcher + e.mu.RUnlock() + if d != nil { + d.DeliverOutboundPacket(remote, local, protocol, pkt) + } +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.mu.Lock() + e.dispatcher = dispatcher + e.mu.Unlock() + // If we're attaching to a valid dispatcher, pass embedder as the dispatcher + // to our child, otherwise detach the child by giving it a nil dispatcher. + var pass stack.NetworkDispatcher + if dispatcher != nil { + pass = e.embedder + } + e.child.Attach(pass) +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + e.mu.RLock() + isAttached := e.dispatcher != nil + e.mu.RUnlock() + return isAttached +} + +// MTU implements stack.LinkEndpoint. +func (e *Endpoint) MTU() uint32 { + return e.child.MTU() +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.child.Capabilities() +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return e.child.MaxHeaderLength() +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.child.LinkAddress() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + return e.child.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + return e.child.WritePackets(r, gso, pkts, protocol) +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + return e.child.WriteRawPacket(vv) +} + +// Wait implements stack.LinkEndpoint. +func (e *Endpoint) Wait() { + e.child.Wait() +} + +// GSOMaxSize implements stack.GSOEndpoint. +func (e *Endpoint) GSOMaxSize() uint32 { + if e, ok := e.child.(stack.GSOEndpoint); ok { + return e.GSOMaxSize() + } + return 0 +} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.child.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.child.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/nested/nested_test.go b/pkg/tcpip/link/nested/nested_test.go new file mode 100644 index 000000000..7d9249c1c --- /dev/null +++ b/pkg/tcpip/link/nested/nested_test.go @@ -0,0 +1,109 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nested_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type parentEndpoint struct { + nested.Endpoint +} + +var _ stack.LinkEndpoint = (*parentEndpoint)(nil) +var _ stack.NetworkDispatcher = (*parentEndpoint)(nil) + +type childEndpoint struct { + stack.LinkEndpoint + dispatcher stack.NetworkDispatcher +} + +var _ stack.LinkEndpoint = (*childEndpoint)(nil) + +func (c *childEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + c.dispatcher = dispatcher +} + +func (c *childEndpoint) IsAttached() bool { + return c.dispatcher != nil +} + +type counterDispatcher struct { + count int +} + +var _ stack.NetworkDispatcher = (*counterDispatcher)(nil) + +func (d *counterDispatcher) DeliverNetworkPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + d.count++ +} + +func (d *counterDispatcher) DeliverOutboundPacket(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) { + panic("unimplemented") +} + +func TestNestedLinkEndpoint(t *testing.T) { + const emptyAddress = tcpip.LinkAddress("") + + var ( + childEP childEndpoint + nestedEP parentEndpoint + disp counterDispatcher + ) + nestedEP.Endpoint.Init(&childEP, &nestedEP) + + if childEP.IsAttached() { + t.Error("On init, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("On init, nestedEP.IsAttached() = true, want = false") + } + + nestedEP.Attach(&disp) + if disp.count != 0 { + t.Fatalf("After attach, got disp.count = %d, want = 0", disp.count) + } + if !childEP.IsAttached() { + t.Error("After attach, childEP.IsAttached() = false, want = true") + } + if !nestedEP.IsAttached() { + t.Error("After attach, nestedEP.IsAttached() = false, want = true") + } + + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 1 { + t.Errorf("After first packet with dispatcher attached, got disp.count = %d, want = 1", disp.count) + } + + nestedEP.Attach(nil) + if childEP.IsAttached() { + t.Error("After detach, childEP.IsAttached() = true, want = false") + } + if nestedEP.IsAttached() { + t.Error("After detach, nestedEP.IsAttached() = true, want = false") + } + + disp.count = 0 + nestedEP.DeliverNetworkPacket(emptyAddress, emptyAddress, header.IPv4ProtocolNumber, &stack.PacketBuffer{}) + if disp.count != 0 { + t.Errorf("After second packet with dispatcher detached, got disp.count = %d, want = 0", disp.count) + } + +} diff --git a/pkg/tcpip/link/packetsocket/BUILD b/pkg/tcpip/link/packetsocket/BUILD new file mode 100644 index 000000000..6fff160ce --- /dev/null +++ b/pkg/tcpip/link/packetsocket/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "packetsocket", + srcs = ["endpoint.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/packetsocket/endpoint.go b/pkg/tcpip/link/packetsocket/endpoint.go new file mode 100644 index 000000000..3922c2a04 --- /dev/null +++ b/pkg/tcpip/link/packetsocket/endpoint.go @@ -0,0 +1,50 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packetsocket provides a link layer endpoint that provides the ability +// to loop outbound packets to any AF_PACKET sockets that may be interested in +// the outgoing packet. +package packetsocket + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type endpoint struct { + nested.Endpoint +} + +// New creates a new packetsocket LinkEndpoint. +func New(lower stack.LinkEndpoint) stack.LinkEndpoint { + e := &endpoint{} + e.Endpoint.Init(lower, e) + return e +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.Endpoint.DeliverOutboundPacket(r.RemoteLinkAddress, r.LocalLinkAddress, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) +} + +// WritePackets implements stack.LinkEndpoint.WritePackets. +func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.Endpoint.DeliverOutboundPacket(pkt.EgressRoute.RemoteLinkAddress, pkt.EgressRoute.LocalLinkAddress, pkt.NetworkProtocolNumber, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} diff --git a/pkg/tcpip/link/qdisc/fifo/BUILD b/pkg/tcpip/link/qdisc/fifo/BUILD index 054c213bc..1d0079bd6 100644 --- a/pkg/tcpip/link/qdisc/fifo/BUILD +++ b/pkg/tcpip/link/qdisc/fifo/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b5dfb7850..467083239 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -106,6 +107,11 @@ func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) +} + // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher @@ -193,6 +199,8 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { + // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } @@ -207,3 +215,13 @@ func (e *endpoint) Wait() { e.wg.Wait() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 44e25d475..f4c32c2da 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -66,39 +66,14 @@ func NonBlockingWrite(fd int, buf []byte) *tcpip.Error { return nil } -// NonBlockingWrite3 writes up to three byte slices to a file descriptor in a -// single syscall. It fails if partial data is written. -func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error { - // If the is no second buffer, issue a regular write. - if len(b2) == 0 { - return NonBlockingWrite(fd, b1) - } - - // We have two buffers. Build the iovec that represents them and issue - // a writev syscall. - iovec := [3]syscall.Iovec{ - { - Base: &b1[0], - Len: uint64(len(b1)), - }, - { - Base: &b2[0], - Len: uint64(len(b2)), - }, - } - iovecLen := uintptr(2) - - if len(b3) > 0 { - iovecLen++ - iovec[2].Base = &b3[0] - iovec[2].Len = uint64(len(b3)) - } - +// NonBlockingWriteIovec writes iovec to a file descriptor in a single syscall. +// It fails if partial data is written. +func NonBlockingWriteIovec(fd int, iovec []syscall.Iovec) *tcpip.Error { + iovecLen := uintptr(len(iovec)) _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) if e != 0 { return TranslateErrno(e) } - return nil } diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go index 0374a2441..507c76b76 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem.go +++ b/pkg/tcpip/link/sharedmem/sharedmem.go @@ -183,22 +183,29 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress { return e.addr } -// WritePacket writes outbound packets to the file descriptor. If it is not -// currently writable, the packet is dropped. -func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - // Add the ethernet header here. +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + // Add ethernet header if needed. eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) pkt.LinkHeader = buffer.View(eth) ethHdr := &header.EthernetFields{ - DstAddr: r.RemoteLinkAddress, + DstAddr: remote, Type: protocol, } - if r.LocalLinkAddress != "" { - ethHdr.SrcAddr = r.LocalLinkAddress + + // Preserve the src address if it's set in the route. + if local != "" { + ethHdr.SrcAddr = local } else { ethHdr.SrcAddr = e.addr } eth.Encode(ethHdr) +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +func (e *endpoint) WritePacket(r *stack.Route, _ *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(r.LocalLinkAddress, r.RemoteLinkAddress, protocol, pkt) v := pkt.Data.ToView() // Transmit the packet. @@ -287,3 +294,8 @@ func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) { e.completed.Done() } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType +func (*endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go index 28a2e88ba..8f3cd9449 100644 --- a/pkg/tcpip/link/sharedmem/sharedmem_test.go +++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go @@ -143,6 +143,10 @@ func (c *testContext) DeliverNetworkPacket(remoteLinkAddr, localLinkAddr tcpip.L c.packetCh <- struct{}{} } +func (c *testContext) DeliverOutboundPacket(remoteLinkAddr, localLinkAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (c *testContext) cleanup() { c.ep.Close() closeFDs(&c.txCfg) diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 230a8d53a..7cbc305e7 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index f2e47b6a7..509076643 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -48,18 +49,21 @@ var LogPackets uint32 = 1 var LogPacketsToPCAP uint32 = 1 type endpoint struct { - dispatcher stack.NetworkDispatcher - lower stack.LinkEndpoint + nested.Endpoint writer io.Writer maxPCAPLen uint32 } +var _ stack.GSOEndpoint = (*endpoint)(nil) +var _ stack.LinkEndpoint = (*endpoint)(nil) +var _ stack.NetworkDispatcher = (*endpoint)(nil) + // New creates a new sniffer link-layer endpoint. It wraps around another // endpoint and logs packets and they traverse the endpoint. func New(lower stack.LinkEndpoint) stack.LinkEndpoint { - return &endpoint{ - lower: lower, - } + sniffer := &endpoint{} + sniffer.Endpoint.Init(lower, sniffer) + return sniffer } func zoneOffset() (int32, error) { @@ -103,11 +107,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( if err := writePCAPHeader(writer, snapLen); err != nil { return nil, err } - return &endpoint{ - lower: lower, + sniffer := &endpoint{ writer: writer, maxPCAPLen: snapLen, - }, nil + } + sniffer.Endpoint.Init(lower, sniffer) + return sniffer, nil } // DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is @@ -115,50 +120,12 @@ func NewWithWriter(lower stack.LinkEndpoint, writer io.Writer, snapLen uint32) ( // logs the packet before forwarding to the actual dispatcher. func (e *endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { e.dumpPacket("recv", nil, protocol, pkt) - e.dispatcher.DeliverNetworkPacket(remote, local, protocol, pkt) -} - -// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher -// and registers with the lower endpoint as its dispatcher so that "e" is called -// for inbound packets. -func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - e.lower.Attach(e) -} - -// IsAttached implements stack.LinkEndpoint.IsAttached. -func (e *endpoint) IsAttached() bool { - return e.dispatcher != nil -} - -// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the -// lower endpoint. -func (e *endpoint) MTU() uint32 { - return e.lower.MTU() + e.Endpoint.DeliverNetworkPacket(remote, local, protocol, pkt) } -// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the -// request to the lower endpoint. -func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.lower.Capabilities() -} - -// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards -// the request to the lower endpoint. -func (e *endpoint) MaxHeaderLength() uint16 { - return e.lower.MaxHeaderLength() -} - -func (e *endpoint) LinkAddress() tcpip.LinkAddress { - return e.lower.LinkAddress() -} - -// GSOMaxSize returns the maximum GSO packet size. -func (e *endpoint) GSOMaxSize() uint32 { - if gso, ok := e.lower.(stack.GSOEndpoint); ok { - return gso.GSOMaxSize() - } - return 0 +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.Endpoint.DeliverOutboundPacket(remote, local, protocol, pkt) } func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { @@ -203,7 +170,7 @@ func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.Netw // forwards the request to the lower endpoint. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { e.dumpPacket("send", gso, protocol, pkt) - return e.lower.WritePacket(r, gso, protocol, pkt) + return e.Endpoint.WritePacket(r, gso, protocol, pkt) } // WritePackets implements the stack.LinkEndpoint interface. It is called by @@ -213,7 +180,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.dumpPacket("send", gso, protocol, pkt) } - return e.lower.WritePackets(r, gso, pkts, protocol) + return e.Endpoint.WritePackets(r, gso, pkts, protocol) } // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. @@ -221,12 +188,9 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { e.dumpPacket("send", nil, 0, &stack.PacketBuffer{ Data: vv, }) - return e.lower.WriteRawPacket(vv) + return e.Endpoint.WriteRawPacket(vv) } -// Wait implements stack.LinkEndpoint.Wait. -func (e *endpoint) Wait() { e.lower.Wait() } - func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) { // Figure out the network layer info. var transProto uint8 diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 6bc9033d0..04ae58e59 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -139,6 +139,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE stack: s, nicID: id, name: name, + isTap: prefix == "tap", } endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { @@ -271,21 +272,9 @@ func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) { if d.hasFlags(linux.IFF_TAP) { // Add ethernet header if not provided. if info.Pkt.LinkHeader == nil { - hdr := &header.EthernetFields{ - SrcAddr: info.Route.LocalLinkAddress, - DstAddr: info.Route.RemoteLinkAddress, - Type: info.Proto, - } - if hdr.SrcAddr == "" { - hdr.SrcAddr = d.endpoint.LinkAddress() - } - - eth := make(header.Ethernet, header.EthernetMinimumSize) - eth.Encode(hdr) - vv.AppendView(buffer.View(eth)) - } else { - vv.AppendView(info.Pkt.LinkHeader) + d.endpoint.AddHeader(info.Route.LocalLinkAddress, info.Route.RemoteLinkAddress, info.Proto, info.Pkt) } + vv.AppendView(info.Pkt.LinkHeader) } // Append upper headers. @@ -348,6 +337,7 @@ type tunEndpoint struct { stack *stack.Stack nicID tcpip.NICID name string + isTap bool } // DecRef decrements refcount of e, removes NIC if refcount goes to 0. @@ -356,3 +346,38 @@ func (e *tunEndpoint) DecRef() { e.stack.RemoveNIC(e.nicID) }) } + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *tunEndpoint) ARPHardwareType() header.ARPHardwareType { + if e.isTap { + return header.ARPHardwareEther + } + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *tunEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.isTap { + return + } + eth := header.Ethernet(pkt.Header.Prepend(header.EthernetMinimumSize)) + pkt.LinkHeader = buffer.View(eth) + hdr := &header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: protocol, + } + if hdr.SrcAddr == "" { + hdr.SrcAddr = e.LinkAddress() + } + + eth.Encode(hdr) +} + +// MaxHeaderLength returns the maximum size of the link layer header. +func (e *tunEndpoint) MaxHeaderLength() uint16 { + if e.isTap { + return header.EthernetMinimumSize + } + return 0 +} diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD index 0956d2c65..ee84c3d96 100644 --- a/pkg/tcpip/link/waitable/BUILD +++ b/pkg/tcpip/link/waitable/BUILD @@ -12,6 +12,7 @@ go_library( "//pkg/gate", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) @@ -25,6 +26,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/header", "//pkg/tcpip/stack", ], ) diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go index 949b3f2b2..b152a0f26 100644 --- a/pkg/tcpip/link/waitable/waitable.go +++ b/pkg/tcpip/link/waitable/waitable.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/gate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,6 +60,15 @@ func (e *Endpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protoco e.dispatchGate.Leave() } +// DeliverOutboundPacket implements stack.NetworkDispatcher.DeliverOutboundPacket. +func (e *Endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + if !e.dispatchGate.Enter() { + return + } + e.dispatcher.DeliverOutboundPacket(remote, local, protocol, pkt) + e.dispatchGate.Leave() +} + // Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and // registers with the lower endpoint as its dispatcher so that "e" is called // for inbound packets. @@ -147,3 +157,13 @@ func (e *Endpoint) WaitDispatch() { // Wait implements stack.LinkEndpoint.Wait. func (e *Endpoint) Wait() {} + +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (e *Endpoint) ARPHardwareType() header.ARPHardwareType { + return e.lower.ARPHardwareType() +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *Endpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + e.lower.AddHeader(local, remote, protocol, pkt) +} diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go index 63bf40562..c448a888f 100644 --- a/pkg/tcpip/link/waitable/waitable_test.go +++ b/pkg/tcpip/link/waitable/waitable_test.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -39,6 +40,10 @@ func (e *countedEndpoint) DeliverNetworkPacket(remote, local tcpip.LinkAddress, e.dispatchCount++ } +func (e *countedEndpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.attachCount++ e.dispatcher = dispatcher @@ -81,9 +86,19 @@ func (e *countedEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return nil } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*countedEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("unimplemented") +} + // Wait implements stack.LinkEndpoint.Wait. func (*countedEndpoint) Wait() {} +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *countedEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("unimplemented") +} + func TestWaitWrite(t *testing.T) { ep := &countedEndpoint{} wep := New(ep) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7f27a840d..31a242482 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { r := &stack.Route{ - RemoteLinkAddress: broadcastMAC, + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) @@ -181,7 +184,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { if addr == header.IPv4Broadcast { - return broadcastMAC, true + return header.EthernetBroadcastAddress, true } if header.IsV4MulticastAddress(addr) { return header.EthernetAddressFromMulticastIPv4Address(addr), true @@ -216,8 +219,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return 0, false, true } -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - // NewProtocol returns an ARP network protocol. func NewProtocol() stack.NetworkProtocol { return &protocol{} diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 66e67429c..a35a64a0f 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -32,10 +32,14 @@ import ( ) const ( - stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + + defaultChannelSize = 1 + defaultMTU = 65536 ) type testContext struct { @@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext { TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, }) - const defaultMTU = 65536 - ep := channel.New(256, defaultMTU, stackLinkAddr) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) wep := stack.LinkEndpoint(ep) if testing.Verbose() { @@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength()) } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { @@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) { t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) } } + +func TestLinkAddressRequest(t *testing.T) { + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: stackLinkAddr2, + expectLinkAddr: stackLinkAddr2, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: header.EthernetBroadcastAddress, + }, + } + + for _, test := range tests { + p := arp.NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) + if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 2982450f8..ffbadb6e2 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -17,6 +17,7 @@ package fragmentation import ( + "errors" "fmt" "log" "time" @@ -25,20 +26,31 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" ) -// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. -const DefaultReassembleTimeout = 30 * time.Second +const ( + // DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. + DefaultReassembleTimeout = 30 * time.Second -// HighFragThreshold is the threshold at which we start trimming old -// fragmented packets. Linux uses a default value of 4 MB. See -// net.ipv4.ipfrag_high_thresh for more information. -const HighFragThreshold = 4 << 20 // 4MB + // HighFragThreshold is the threshold at which we start trimming old + // fragmented packets. Linux uses a default value of 4 MB. See + // net.ipv4.ipfrag_high_thresh for more information. + HighFragThreshold = 4 << 20 // 4MB -// LowFragThreshold is the threshold we reach to when we start dropping -// older fragmented packets. It's important that we keep enough room for newer -// packets to be re-assembled. Hence, this needs to be lower than -// HighFragThreshold enough. Linux uses a default value of 3 MB. See -// net.ipv4.ipfrag_low_thresh for more information. -const LowFragThreshold = 3 << 20 // 3MB + // LowFragThreshold is the threshold we reach to when we start dropping + // older fragmented packets. It's important that we keep enough room for newer + // packets to be re-assembled. Hence, this needs to be lower than + // HighFragThreshold enough. Linux uses a default value of 3 MB. See + // net.ipv4.ipfrag_low_thresh for more information. + LowFragThreshold = 3 << 20 // 3MB + + // minBlockSize is the minimum block size for fragments. + minBlockSize = 1 +) + +var ( + // ErrInvalidArgs indicates to the caller that that an invalid argument was + // provided. + ErrInvalidArgs = errors.New("invalid args") +) // Fragmentation is the main structure that other modules // of the stack should use to implement IP Fragmentation. @@ -50,10 +62,13 @@ type Fragmentation struct { rList reassemblerList size int timeout time.Duration + blockSize uint16 } // NewFragmentation creates a new Fragmentation. // +// blockSize specifies the fragment block size, in bytes. +// // highMemoryLimit specifies the limit on the memory consumed // by the fragments stored by Fragmentation (overhead of internal data-structures // is not accounted). Fragments are dropped when the limit is reached. @@ -64,7 +79,7 @@ type Fragmentation struct { // reassemblingTimeout specifies the maximum time allowed to reassemble a packet. // Fragments are lazily evicted only when a new a packet with an // already existing fragmentation-id arrives after the timeout. -func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { +func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation { if lowMemoryLimit >= highMemoryLimit { lowMemoryLimit = highMemoryLimit } @@ -73,17 +88,46 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t lowMemoryLimit = 0 } + if blockSize < minBlockSize { + blockSize = minBlockSize + } + return &Fragmentation{ reassemblers: make(map[uint32]*reassembler), highLimit: highMemoryLimit, lowLimit: lowMemoryLimit, timeout: reassemblingTimeout, + blockSize: blockSize, } } // Process processes an incoming fragment belonging to an ID and returns a // complete packet when all the packets belonging to that ID have been received. +// +// [first, last] is the range of the fragment bytes. +// +// first must be a multiple of the block size f is configured with. The size +// of the fragment data must be a multiple of the block size, unless there are +// no fragments following this fragment (more set to false). func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { + if first > last { + return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + } + + if first%f.blockSize != 0 { + return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + } + + fragmentSize := last - first + 1 + if more && fragmentSize%f.blockSize != 0 { + return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + } + + if l := vv.Size(); l < int(fragmentSize) { + return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + } + vv.CapLength(int(fragmentSize)) + f.mu.Lock() r, ok := f.reassemblers[id] if ok && r.tooOld(f.timeout) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 72c0f53be..ebc3232e5 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -15,6 +15,7 @@ package fragmentation import ( + "errors" "reflect" "testing" "time" @@ -81,7 +82,7 @@ var processTestCases = []struct { func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { - f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout) for i, in := range c.in { vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) if err != nil { @@ -110,7 +111,7 @@ func TestFragmentationProcess(t *testing.T) { func TestReassemblingTimeout(t *testing.T) { timeout := time.Millisecond - f := NewFragmentation(1024, 512, timeout) + f := NewFragmentation(minBlockSize, 1024, 512, timeout) // Send first fragment with id = 0, first = 0, last = 0, and more = true. f.Process(0, 0, 0, true, vv(1, "0")) // Sleep more than the timeout. @@ -127,7 +128,7 @@ func TestReassemblingTimeout(t *testing.T) { } func TestMemoryLimits(t *testing.T) { - f := NewFragmentation(3, 1, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout) // Send first fragment with id = 0. f.Process(0, 0, 0, true, vv(1, "0")) // Send first fragment with id = 1. @@ -151,7 +152,7 @@ func TestMemoryLimits(t *testing.T) { } func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { - f := NewFragmentation(1, 0, DefaultReassembleTimeout) + f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout) // Send first fragment with id = 0. f.Process(0, 0, 0, true, vv(1, "0")) // Send the same packet again. @@ -163,3 +164,99 @@ func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) } } + +func TestErrors(t *testing.T) { + const fragID = 5 + + tests := []struct { + name string + blockSize uint16 + first uint16 + last uint16 + more bool + data string + err error + }{ + { + name: "exact block size without more", + blockSize: 2, + first: 2, + last: 3, + more: false, + data: "01", + }, + { + name: "exact block size with more", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "01", + }, + { + name: "exact block size with more and extra data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "012", + }, + { + name: "exact block size with more and too little data", + blockSize: 2, + first: 2, + last: 3, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size with more", + blockSize: 2, + first: 2, + last: 2, + more: true, + data: "0", + err: ErrInvalidArgs, + }, + { + name: "not exact block size without more", + blockSize: 2, + first: 2, + last: 2, + more: false, + data: "0", + }, + { + name: "first not a multiple of block size", + blockSize: 2, + first: 3, + last: 4, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + { + name: "first more than last", + blockSize: 2, + first: 4, + last: 3, + more: true, + data: "01", + err: ErrInvalidArgs, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout) + _, done, err := f.Process(fragID, test.first, test.last, test.more, vv(len(test.data), test.data)) + if !errors.Is(err, test.err) { + t.Errorf("got Proceess(%d, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", fragID, test.first, test.last, test.more, test.data, err, test.err) + } + if done { + t.Errorf("got Proceess(%d, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", fragID, test.first, test.last, test.more, test.data) + } + }) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 7c8fb3e0a..615bae648 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -172,14 +172,24 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, protocol tcpip.Ne } // WritePackets implements stack.LinkEndpoint.WritePackets. -func (t *testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { +func (*testObject) WritePackets(_ *stack.Route, _ *stack.GSO, pkt stack.PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { panic("not implemented") } -func (t *testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { +func (*testObject) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testObject) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (*testObject) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + panic("not implemented") +} + func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 78420d6e6..d142b4ffa 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -34,6 +34,6 @@ go_test( "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 1b67aa066..83e71cb8c 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -129,6 +129,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { pkt.Data.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { + case header.ICMPv4HostUnreachable: + e.handleControl(stack.ControlNoRoute, 0, pkt) + case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 7e9f16c90..0b5a35cce 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -45,6 +45,10 @@ const ( // buckets is the number of identifier buckets. buckets = 2048 + + // The size of a fragment block, in bytes, as per RFC 791 section 3.1, + // page 14. + fragmentblockSize = 8 ) type endpoint struct { @@ -66,7 +70,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi prefixLen: addrWithPrefix.PrefixLen, linkEP: linkEP, dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, stack: st, } @@ -225,12 +229,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadSize int, params stack.NetworkHeaderParams) header.IPv4 { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payloadSize) - id := uint32(0) - if length > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) - } + // RFC 6864 section 4.3 mandates uniqueness of ID values for non-atomic + // datagrams. Since the DF bit is never being set here, all datagrams + // are non-atomic and need an ID. + id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, @@ -376,13 +378,12 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // Set the packet ID when zero. if ip.ID() == 0 { - id := uint32(0) - if pkt.Data.Size() > header.IPv4MaximumHeaderSize+8 { - // Packets of 68 bytes or less are required by RFC 791 to not be - // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1) + // RFC 6864 section 4.3 mandates uniqueness of ID values for + // non-atomic datagrams, so assign an ID to all such datagrams + // according to the definition given in RFC 6864 section 4. + if ip.Flags()&header.IPv4FlagDontFragment == 0 || ip.Flags()&header.IPv4FlagMoreFragments != 0 || ip.FragmentOffset() > 0 { + ip.SetID(uint16(atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1))) } - ip.SetID(uint16(id)) } // Always set the checksum. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 11e579c4b..4f82c45e2 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -519,6 +519,11 @@ func TestReceiveFragments(t *testing.T) { // UDP header plus a payload of 0..256 in increments of 2. ipv4Payload2 := udpGen(128, 2) udpPayload2 := ipv4Payload2[header.UDPMinimumSize:] + // UDP header plus a payload of 0..256 in increments of 3. + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 791 section 3.1 page 14). + ipv4Payload3 := udpGen(127, 3) + udpPayload3 := ipv4Payload3[header.UDPMinimumSize:] type fragmentData struct { id uint16 @@ -545,6 +550,18 @@ func TestReceiveFragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload1}, }, { + name: "No fragmentation with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + id: 1, + flags: 0, + fragmentOffset: 0, + payload: ipv4Payload3, + }, + }, + expectedPayloads: [][]byte{udpPayload3}, + }, + { name: "More fragments without payload", fragments: []fragmentData{ { @@ -587,6 +604,42 @@ func TestReceiveFragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload1}, }, { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3[:64], + }, + { + id: 1, + flags: 0, + fragmentOffset: 64, + payload: ipv4Payload3[64:], + }, + }, + expectedPayloads: [][]byte{udpPayload3}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload3[:63], + }, + { + id: 1, + flags: 0, + fragmentOffset: 63, + payload: ipv4Payload3[63:], + }, + }, + expectedPayloads: nil, + }, + { name: "Second fragment has MoreFlags set", fragments: []fragmentData{ { diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index 3f71fc520..feada63dc 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -39,6 +39,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 2ff7eedf4..24600d877 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -128,6 +128,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } pkt.Data.TrimFront(header.ICMPv6DstUnreachableMinimumSize) switch header.ICMPv6(hdr).Code() { + case header.ICMPv6NetworkUnreachable: + e.handleControl(stack.ControlNetworkUnreachable, 0, pkt) case header.ICMPv6PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, pkt) } @@ -494,8 +496,6 @@ const ( icmpV6LengthOffset = 25 ) -var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) - var _ stack.LinkAddressResolver = (*protocol)(nil) // LinkAddressProtocol implements stack.LinkAddressResolver. @@ -504,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { snaddr := header.SolicitedNodeAddr(addr) // TODO(b/148672031): Use stack.FindRoute instead of manually creating the @@ -513,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack. r := &stack.Route{ LocalAddress: localAddr, RemoteAddress: snaddr, - RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr), + RemoteLinkAddress: remoteLinkAddr, + } + if len(r.RemoteLinkAddress) == 0 { + r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr) } + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize)) pkt.SetType(header.ICMPv6NeighborSolicit) diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 52a01b44e..f86aaed1d 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -34,6 +34,9 @@ const ( linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") + + defaultChannelSize = 1 + defaultMTU = 65536 ) var ( @@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext { }), } - const defaultMTU = 65536 - c.linkEP0 = channel.New(256, defaultMTU, linkAddr0) + c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0) wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0}) if testing.Verbose() { @@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext { t.Fatalf("AddAddress lladdr0: %v", err) } - c.linkEP1 = channel.New(256, defaultMTU, linkAddr1) + c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) @@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { }) } } + +func TestLinkAddressRequest(t *testing.T) { + snaddr := header.SolicitedNodeAddr(lladdr0) + mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) + + tests := []struct { + name string + remoteLinkAddr tcpip.LinkAddress + expectLinkAddr tcpip.LinkAddress + }{ + { + name: "Unicast", + remoteLinkAddr: linkAddr1, + expectLinkAddr: linkAddr1, + }, + { + name: "Multicast", + remoteLinkAddr: "", + expectLinkAddr: mcaddr, + }, + } + + for _, test := range tests { + p := NewProtocol() + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver") + } + + linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) + if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) + } + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 95fbcf2d1..5483ae4ee 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -467,7 +467,7 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi linkEP: linkEP, linkAddrCache: linkAddrCache, dispatcher: dispatcher, - fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout), protocol: p, }, nil } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 213ff64f2..84bac14ff 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -678,13 +678,18 @@ type fragmentData struct { } func TestReceiveIPv6Fragments(t *testing.T) { - const nicID = 1 - const udpPayload1Length = 256 - const udpPayload2Length = 128 - const fragmentExtHdrLen = 8 - // Note, not all routing extension headers will be 8 bytes but this test - // uses 8 byte routing extension headers for most sub tests. - const routingExtHdrLen = 8 + const ( + nicID = 1 + udpPayload1Length = 256 + udpPayload2Length = 128 + // Used to test cases where the fragment blocks are not a multiple of + // the fragment block size of 8 (RFC 8200 section 4.5). + udpPayload3Length = 127 + fragmentExtHdrLen = 8 + // Note, not all routing extension headers will be 8 bytes but this test + // uses 8 byte routing extension headers for most sub tests. + routingExtHdrLen = 8 + ) udpGen := func(payload []byte, multiplier uint8) buffer.View { payloadLen := len(payload) @@ -716,6 +721,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload2 := udpPayload2Buf[:] ipv6Payload2 := udpGen(udpPayload2, 2) + var udpPayload3Buf [udpPayload3Length]byte + udpPayload3 := udpPayload3Buf[:] + ipv6Payload3 := udpGen(udpPayload3, 3) + tests := []struct { name string expectedPayload []byte @@ -751,6 +760,24 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload1}, }, { + name: "Atomic fragment with size not a multiple of fragment block size", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3), + []buffer.View{ + // Fragment extension header. + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 0, 0, 0, 0, 0}), + + ipv6Payload3, + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3}, + }, + { name: "Two fragments", fragments: []fragmentData{ { @@ -785,6 +812,74 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload1}, }, { + name: "Two fragments with last fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3[:64], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload3}, + }, + { + name: "Two fragments with first fragment size not a multiple of fragment block size", + fragments: []fragmentData{ + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload3[:63], + }, + ), + }, + { + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload3)-63, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload3[63:], + }, + ), + }, + }, + expectedPayloads: nil, + }, + { name: "Two fragments with different IDs", fragments: []fragmentData{ { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index edc29ad27..f6d592eb5 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -52,6 +52,9 @@ type Flags struct { // // LoadBalanced takes precidence over MostRecent. LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool } // Bits converts the Flags to their bitset form. @@ -63,6 +66,9 @@ func (f Flags) Bits() BitFlags { if f.LoadBalanced { rf |= LoadBalancedFlag } + if f.TupleOnly { + rf |= TupleOnlyFlag + } return rf } @@ -98,6 +104,9 @@ const ( // LoadBalancedFlag represents Flags.LoadBalanced. LoadBalancedFlag + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + // nextFlag is the value that the next added flag will have. // // It is used to calculate FlagMask below. It is also the number of @@ -106,6 +115,10 @@ const ( // FlagMask is a bit mask for BitFlags. FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag ) // ToFlags converts the bitset into a Flags struct. @@ -113,6 +126,7 @@ func (f BitFlags) ToFlags() Flags { return Flags{ MostRecent: f&MostRecentFlag != 0, LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, } } @@ -175,9 +189,54 @@ func (c FlagCounter) IntersectionRefs() BitFlags { return intersection } +type destination struct { + addr tcpip.Address + port uint16 +} + +func makeDestination(a tcpip.FullAddress) destination { + return destination{ + a.Addr, + a.Port, + } +} + +// portNode is never empty. When it has no elements, it is removed from the +// map that references it. +type portNode map[destination]FlagCounter + +// intersectionRefs calculates the intersection of flag bit values which affect +// the specified destination. +// +// If no destinations are present, all flag values are returned as there are no +// entries to limit possible flag values of a new entry. +// +// In addition to the intersection, the number of intersecting refs is +// returned. +func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { + intersection := FlagMask + var count int + + for d, f := range p { + if d == dst { + intersection &= f.IntersectionRefs() + count++ + continue + } + // Wildcard destinations affect all destinations for TupleOnly. + if d.addr == anyIPAddress || dst.addr == anyIPAddress { + // Only bitwise and the TupleOnlyFlag. + intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + count++ + } + } + + return intersection, count +} + // deviceNode is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]FlagCounter +type deviceNode map[tcpip.NICID]portNode // isAvailable checks whether binding is possible by device. If not binding to a // device, check against all FlagCounters. If binding to a specific device, check @@ -186,17 +245,15 @@ type deviceNode map[tcpip.NICID]FlagCounter // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { +func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { flagBits := flags.Bits() if bindToDevice == 0 { - // Trying to binding all devices. - if flagBits == 0 { - // Can't bind because the (addr,port) is already bound. - return false - } intersection := FlagMask for _, p := range d { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) + if c == 0 { + continue + } intersection &= i if intersection&flagBits == 0 { // Can't bind because the (addr,port) was @@ -210,16 +267,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool { intersection := FlagMask if p, ok := d[0]; ok { - intersection = p.IntersectionRefs() - if intersection&flagBits == 0 { + var c int + intersection, c = p.intersectionRefs(dst) + if c > 0 && intersection&flagBits == 0 { return false } } if p, ok := d[bindToDevice]; ok { - i := p.IntersectionRefs() + i, c := p.intersectionRefs(dst) intersection &= i - if intersection&flagBits == 0 { + if c > 0 && intersection&flagBits == 0 { return false } } @@ -233,12 +291,12 @@ type bindAddresses map[tcpip.Address]deviceNode // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID) bool { +func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { if addr == anyIPAddress { // If binding to the "any" address then check that there are no conflicts // with all addresses. for _, d := range b { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -247,14 +305,14 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice // Check that there is no conflict with the "any" address. if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } // Check that this is no conflict with the provided address. if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice) { + if !d.isAvailable(flags, bindToDevice, dst) { return false } } @@ -320,17 +378,17 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) + return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice) { + if !addrs.isAvailable(addr, flags, bindToDevice, dst) { return false } } @@ -342,14 +400,16 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() + dst := makeDestination(dest) + // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice) { + if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } return port, nil @@ -357,15 +417,16 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice), nil + return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { return false } + flagBits := flags.Bits() // Reserve port on all network protocols. @@ -381,9 +442,65 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber d = make(deviceNode) m[addr] = d } - n := d[bindToDevice] + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + n := p[dst] n.AddRef(flagBits) - d[bindToDevice] = n + p[dst] = n + d[bindToDevice] = p + } + + return true +} + +// ReserveTuple adds a port reservation for the tuple on all given protocol. +func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { + flagBits := flags.Bits() + dst := makeDestination(dest) + + s.mu.Lock() + defer s.mu.Unlock() + + // It is easier to undo the entire reservation, so if we find that the + // tuple can't be fully added, finish and undo the whole thing. + undo := false + + // Reserve port on all network protocols. + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + p := d[bindToDevice] + if p == nil { + p = make(portNode) + } + + n := p[dst] + if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + // Tuple already exists. + undo = true + } + n.AddRef(flagBits) + p[dst] = n + d[bindToDevice] = p + } + + if undo { + // releasePortLocked decrements the counts (rather than setting + // them to zero), so it will undo the incorrect incrementing + // above. + s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + return false } return true @@ -391,12 +508,14 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { s.mu.Lock() defer s.mu.Unlock() - flagBits := flags.Bits() + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) +} +func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { @@ -404,21 +523,32 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp if !ok { continue } - n, ok := d[bindToDevice] + p, ok := d[bindToDevice] if !ok { continue } - n.refs[flagBits]-- - d[bindToDevice] = n - if n.TotalRefs() == 0 { - delete(d, bindToDevice) + n, ok := p[dst] + if !ok { + continue + } + n.DropRef(flags) + if n.TotalRefs() > 0 { + p[dst] = n + continue } - if len(d) == 0 { - delete(m, addr) + delete(p, dst) + if len(p) > 0 { + continue + } + delete(d, bindToDevice) + if len(d) > 0 { + continue } - if len(m) == 0 { - delete(s.allocatedPorts, desc) + delete(m, addr) + if len(m) > 0 { + continue } + delete(s.allocatedPorts, desc) } } } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index d6969d050..58db5868c 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -36,6 +36,7 @@ type portReserveTestAction struct { flags Flags release bool device tcpip.NICID + dest tcpip.FullAddress } func TestPortReservation(t *testing.T) { @@ -272,6 +273,54 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true, LoadBalanced: true}, want: nil}, {port: 24, ip: fakeIPAddress, flags: Flags{MostRecent: true}, want: tcpip.ErrPortInUse}, }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard with reuseaddr, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + }, + }, { + tname: "bind tuple with reuseaddr, and then wildcard", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind two tuples with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind two tuples", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: nil}, + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 25}, want: nil}, + }, + }, { + tname: "bind wildcard, and then tuple with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, dest: tcpip.FullAddress{}, want: nil}, + {port: 24, ip: fakeIPAddress, flags: Flags{TupleOnly: true}, dest: tcpip.FullAddress{Addr: fakeIPAddress, Port: 24}, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind wildcard twice with reuseaddr", + actions: []portReserveTestAction{ + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + {port: 24, ip: anyIPAddress, flags: Flags{TupleOnly: true}, want: nil}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -280,19 +329,18 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d) = %v, want %v", test.ip, test.port, test.flags, test.device, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) } } }) - } } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 24f52b735..1c58bed2d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -16,6 +16,18 @@ go_template_instance( ) go_template_instance( + name = "neighbor_entry_list", + out = "neighbor_entry_list.go", + package = "stack", + prefix = "neighborEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*neighborEntry", + "Linker": "*neighborEntry", + }, +) + +go_template_instance( name = "packet_buffer_list", out = "packet_buffer_list.go", package = "stack", @@ -27,6 +39,18 @@ go_template_instance( }, ) +go_template_instance( + name = "tuple_list", + out = "tuple_list.go", + package = "stack", + prefix = "tuple", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*tuple", + "Linker": "*tuple", + }, +) + go_library( name = "stack", srcs = [ @@ -35,12 +59,18 @@ go_library( "forwarder.go", "icmp_rate_limit.go", "iptables.go", + "iptables_state.go", "iptables_targets.go", "iptables_types.go", "linkaddrcache.go", "linkaddrentry_list.go", "ndp.go", + "neighbor_cache.go", + "neighbor_entry.go", + "neighbor_entry_list.go", + "neighborstate_string.go", "nic.go", + "nud.go", "packet_buffer.go", "packet_buffer_list.go", "rand.go", @@ -48,7 +78,9 @@ go_library( "route.go", "stack.go", "stack_global_state.go", + "stack_options.go", "transport_demuxer.go", + "tuple_list.go", ], visibility = ["//visibility:public"], deps = [ @@ -74,10 +106,12 @@ go_test( size = "medium", srcs = [ "ndp_test.go", + "nud_test.go", "stack_test.go", "transport_demuxer_test.go", "transport_test.go", ], + shard_count = 20, deps = [ ":stack", "//pkg/rand", @@ -93,7 +127,8 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/udp", "//pkg/waiter", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) @@ -101,8 +136,11 @@ go_test( name = "stack_test", size = "small", srcs = [ + "fake_time_test.go", "forwarder_test.go", "linkaddrcache_test.go", + "neighbor_cache_test.go", + "neighbor_entry_test.go", "nic_test.go", ], library = ":stack", @@ -112,5 +150,8 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "@com_github_dpjacques_clockwork//:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 05bf62788..559a1c4dd 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -26,280 +26,321 @@ import ( ) // Connection tracking is used to track and manipulate packets for NAT rules. -// The connection is created for a packet if it does not exist. Every connection -// contains two tuples (original and reply). The tuples are manipulated if there -// is a matching NAT rule. The packet is modified by looking at the tuples in the -// Prerouting and Output hooks. +// The connection is created for a packet if it does not exist. Every +// connection contains two tuples (original and reply). The tuples are +// manipulated if there is a matching NAT rule. The packet is modified by +// looking at the tuples in the Prerouting and Output hooks. +// +// Currently, only TCP tracking is supported. + +// Our hash table has 16K buckets. +// TODO(gvisor.dev/issue/170): These should be tunable. +const numBuckets = 1 << 14 // Direction of the tuple. -type ctDirection int +type direction int const ( - dirOriginal ctDirection = iota + dirOriginal direction = iota dirReply ) -// Status of connection. -// TODO(gvisor.dev/issue/170): Add other states of connection. -type connStatus int - -const ( - connNew connStatus = iota - connEstablished -) - // Manipulation type for the connection. type manipType int const ( - manipDstPrerouting manipType = iota + manipNone manipType = iota + manipDstPrerouting manipDstOutput ) -// connTrackMutable is the manipulatable part of the tuple. -type connTrackMutable struct { - // addr is source address of the tuple. - addr tcpip.Address - - // port is source port of the tuple. - port uint16 - - // protocol is network layer protocol. - protocol tcpip.NetworkProtocolNumber -} - -// connTrackImmutable is the non-manipulatable part of the tuple. -type connTrackImmutable struct { - // addr is destination address of the tuple. - addr tcpip.Address +// tuple holds a connection's identifying and manipulating data in one +// direction. It is immutable. +// +// +stateify savable +type tuple struct { + // tupleEntry is used to build an intrusive list of tuples. + tupleEntry - // direction is direction (original or reply) of the tuple. - direction ctDirection + tupleID - // port is destination port of the tuple. - port uint16 + // conn is the connection tracking entry this tuple belongs to. + conn *conn - // protocol is transport layer protocol. - protocol tcpip.TransportProtocolNumber + // direction is the direction of the tuple. + direction direction } -// connTrackTuple represents the tuple which is created from the -// packet. -type connTrackTuple struct { - // dst is non-manipulatable part of the tuple. - dst connTrackImmutable - - // src is manipulatable part of the tuple. - src connTrackMutable +// tupleID uniquely identifies a connection in one direction. It currently +// contains enough information to distinguish between any TCP or UDP +// connection, and will need to be extended to support other protocols. +// +// +stateify savable +type tupleID struct { + srcAddr tcpip.Address + srcPort uint16 + dstAddr tcpip.Address + dstPort uint16 + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber } -// connTrackTupleHolder is the container of tuple and connection. -type ConnTrackTupleHolder struct { - // conn is pointer to the connection tracking entry. - conn *connTrack - - // tuple is original or reply tuple. - tuple connTrackTuple +// reply creates the reply tupleID. +func (ti tupleID) reply() tupleID { + return tupleID{ + srcAddr: ti.dstAddr, + srcPort: ti.dstPort, + dstAddr: ti.srcAddr, + dstPort: ti.srcPort, + transProto: ti.transProto, + netProto: ti.netProto, + } } -// connTrack is the connection. -type connTrack struct { - // originalTupleHolder contains tuple in original direction. - originalTupleHolder ConnTrackTupleHolder - - // replyTupleHolder contains tuple in reply direction. - replyTupleHolder ConnTrackTupleHolder - - // status indicates connection is new or established. - status connStatus +// conn is a tracked connection. +// +// +stateify savable +type conn struct { + // original is the tuple in original direction. It is immutable. + original tuple - // timeout indicates the time connection should be active. - timeout time.Duration + // reply is the tuple in reply direction. It is immutable. + reply tuple - // manip indicates if the packet should be manipulated. + // manip indicates if the packet should be manipulated. It is immutable. manip manipType - // tcb is TCB control block. It is used to keep track of states - // of tcp connection. - tcb tcpconntrack.TCB - // tcbHook indicates if the packet is inbound or outbound to - // update the state of tcb. + // update the state of tcb. It is immutable. tcbHook Hook + + // mu protects all mutable state. + mu sync.Mutex `state:"nosave"` + // tcb is TCB control block. It is used to keep track of states + // of tcp connection and is protected by mu. + tcb tcpconntrack.TCB + // lastUsed is the last time the connection saw a relevant packet, and + // is updated by each packet on the connection. It is protected by mu. + lastUsed time.Time `state:".(unixTime)"` } -// ConnTrackTable contains a map of all existing connections created for -// NAT rules. -type ConnTrackTable struct { - // connMu protects connTrackTable. - connMu sync.RWMutex +// timedOut returns whether the connection timed out based on its state. +func (cn *conn) timedOut(now time.Time) bool { + const establishedTimeout = 5 * 24 * time.Hour + const defaultTimeout = 120 * time.Second + cn.mu.Lock() + defer cn.mu.Unlock() + if cn.tcb.State() == tcpconntrack.ResultAlive { + // Use the same default as Linux, which doesn't delete + // established connections for 5(!) days. + return now.Sub(cn.lastUsed) > establishedTimeout + } + // Use the same default as Linux, which lets connections in most states + // other than established remain for <= 120 seconds. + return now.Sub(cn.lastUsed) > defaultTimeout +} - // connTrackTable maintains a map of tuples needed for connection tracking - // for iptables NAT rules. The key for the map is an integer calculated - // using seed, source address, destination address, source port and - // destination port. - CtMap map[uint32]ConnTrackTupleHolder +// update the connection tracking state. +// +// Precondition: ct.mu must be held. +func (ct *conn) updateLocked(tcpHeader header.TCP, hook Hook) { + // Update the state of tcb. tcb assumes it's always initialized on the + // client. However, we only need to know whether the connection is + // established or not, so the client/server distinction isn't important. + // TODO(gvisor.dev/issue/170): Add support in tcpconntrack to handle + // other tcp states. + if ct.tcb.IsEmpty() { + ct.tcb.Init(tcpHeader) + } else if hook == ct.tcbHook { + ct.tcb.UpdateStateOutbound(tcpHeader) + } else { + ct.tcb.UpdateStateInbound(tcpHeader) + } +} +// ConnTrack tracks all connections created for NAT rules. Most users are +// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop. +// +// ConnTrack keeps all connections in a slice of buckets, each of which holds a +// linked list of tuples. This gives us some desirable properties: +// - Each bucket has its own lock, lessening lock contention. +// - The slice is large enough that lists stay short (<10 elements on average). +// Thus traversal is fast. +// - During linked list traversal we reap expired connections. This amortizes +// the cost of reaping them and makes reapUnused faster. +// +// Locks are ordered by their location in the buckets slice. That is, a +// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j. +// +// +stateify savable +type ConnTrack struct { // seed is a one-time random value initialized at stack startup - // and is used in calculation of hash key for connection tracking - // table. - Seed uint32 + // and is used in the calculation of hash keys for the list of buckets. + // It is immutable. + seed uint32 + + // mu protects the buckets slice, but not buckets' contents. Only take + // the write lock if you are modifying the slice or saving for S/R. + mu sync.RWMutex `state:"nosave"` + + // buckets is protected by mu. + buckets []bucket } -// packetToTuple converts packet to a tuple in original direction. -func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) { - var tuple connTrackTuple +// +stateify savable +type bucket struct { + // mu protects tuples. + mu sync.Mutex `state:"nosave"` + tuples tupleList +} - netHeader := header.IPv4(pkt.NetworkHeader) +// packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid +// TCP header. +func packetToTupleID(pkt *PacketBuffer) (tupleID, *tcpip.Error) { // TODO(gvisor.dev/issue/170): Need to support for other // protocols as well. + netHeader := header.IPv4(pkt.NetworkHeader) if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return tuple, tcpip.ErrUnknownProtocol + return tupleID{}, tcpip.ErrUnknownProtocol } - tuple.src.addr = netHeader.SourceAddress() - tuple.src.port = tcpHeader.SourcePort() - tuple.src.protocol = header.IPv4ProtocolNumber - - tuple.dst.addr = netHeader.DestinationAddress() - tuple.dst.port = tcpHeader.DestinationPort() - tuple.dst.protocol = netHeader.TransportProtocol() - - return tuple, nil + return tupleID{ + srcAddr: netHeader.SourceAddress(), + srcPort: tcpHeader.SourcePort(), + dstAddr: netHeader.DestinationAddress(), + dstPort: tcpHeader.DestinationPort(), + transProto: netHeader.TransportProtocol(), + netProto: header.IPv4ProtocolNumber, + }, nil } -// getReplyTuple creates reply tuple for the given tuple. -func getReplyTuple(tuple connTrackTuple) connTrackTuple { - var replyTuple connTrackTuple - replyTuple.src.addr = tuple.dst.addr - replyTuple.src.port = tuple.dst.port - replyTuple.src.protocol = tuple.src.protocol - replyTuple.dst.addr = tuple.src.addr - replyTuple.dst.port = tuple.src.port - replyTuple.dst.protocol = tuple.dst.protocol - replyTuple.dst.direction = dirReply - - return replyTuple +// newConn creates new connection. +func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn { + conn := conn{ + manip: manip, + tcbHook: hook, + lastUsed: time.Now(), + } + conn.original = tuple{conn: &conn, tupleID: orig} + conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply} + return &conn } -// makeNewConn creates new connection. -func makeNewConn(tuple, replyTuple connTrackTuple) connTrack { - var conn connTrack - conn.status = connNew - conn.originalTupleHolder.tuple = tuple - conn.originalTupleHolder.conn = &conn - conn.replyTupleHolder.tuple = replyTuple - conn.replyTupleHolder.conn = &conn +// connFor gets the conn for pkt if it exists, or returns nil +// if it does not. It returns an error when pkt does not contain a valid TCP +// header. +// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support +// other transport protocols. +func (ct *ConnTrack) connFor(pkt *PacketBuffer) (*conn, direction) { + tid, err := packetToTupleID(pkt) + if err != nil { + return nil, dirOriginal + } - return conn -} + bucket := ct.bucket(tid) + now := time.Now() + + ct.mu.RLock() + defer ct.mu.RUnlock() + ct.buckets[bucket].mu.Lock() + defer ct.buckets[bucket].mu.Unlock() + + // Iterate over the tuples in a bucket, cleaning up any unused + // connections we find. + for other := ct.buckets[bucket].tuples.Front(); other != nil; other = other.Next() { + // Clean up any timed-out connections we happen to find. + if ct.reapTupleLocked(other, bucket, now) { + // The tuple expired. + continue + } + if tid == other.tupleID { + return other.conn, other.direction + } + } -// getTupleHash returns hash of the tuple. The fields used for -// generating hash are seed (generated once for stack), source address, -// destination address, source port and destination ports. -func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 { - h := jenkins.Sum32(ct.Seed) - h.Write([]byte(tuple.src.addr)) - h.Write([]byte(tuple.dst.addr)) - portBuf := make([]byte, 2) - binary.LittleEndian.PutUint16(portBuf, tuple.src.port) - h.Write([]byte(portBuf)) - binary.LittleEndian.PutUint16(portBuf, tuple.dst.port) - h.Write([]byte(portBuf)) - - return h.Sum32() + return nil, dirOriginal } -// connTrackForPacket returns connTrack for packet. -// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other -// transport protocols. -func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) { - var dir ctDirection - tuple, err := packetToTuple(pkt, hook) +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt RedirectTarget) *conn { + tid, err := packetToTupleID(pkt) if err != nil { - return nil, dir - } - - ct.connMu.Lock() - defer ct.connMu.Unlock() - - connTrackTable := ct.CtMap - hash := ct.getTupleHash(tuple) - - var conn *connTrack - switch createConn { - case true: - // If connection does not exist for the hash, create a new - // connection. - replyTuple := getReplyTuple(tuple) - replyHash := ct.getTupleHash(replyTuple) - newConn := makeNewConn(tuple, replyTuple) - conn = &newConn - - // Add tupleHolders to the map. - // TODO(gvisor.dev/issue/170): Need to support collisions using linked list. - ct.CtMap[hash] = conn.originalTupleHolder - ct.CtMap[replyHash] = conn.replyTupleHolder - default: - tupleHolder, ok := connTrackTable[hash] - if !ok { - return nil, dir - } - - // If this is the reply of new connection, set the connection - // status as ESTABLISHED. - conn = tupleHolder.conn - if conn.status == connNew && tupleHolder.tuple.dst.direction == dirReply { - conn.status = connEstablished - } - if tupleHolder.conn == nil { - panic("tupleHolder has null connection tracking entry") - } + return nil + } + if hook != Prerouting && hook != Output { + return nil + } - dir = tupleHolder.tuple.dst.direction + // Create a new connection and change the port as per the iptables + // rule. This tuple will be used to manipulate the packet in + // handlePacket. + replyTID := tid.reply() + replyTID.srcAddr = rt.MinIP + replyTID.srcPort = rt.MinPort + var manip manipType + switch hook { + case Prerouting: + manip = manipDstPrerouting + case Output: + manip = manipDstOutput } - return conn, dir + conn := newConn(tid, replyTID, manip, hook) + ct.insertConn(conn) + return conn } -// SetNatInfo will manipulate the tuples according to iptables NAT rules. -func (ct *ConnTrackTable) SetNatInfo(pkt *PacketBuffer, rt RedirectTarget, hook Hook) { - // Get the connection. Connection is always created before this - // function is called. - conn, _ := ct.connTrackForPacket(pkt, hook, false) - if conn == nil { - panic("connection should be created to manipulate tuples.") +// insertConn inserts conn into the appropriate table bucket. +func (ct *ConnTrack) insertConn(conn *conn) { + // Lock the buckets in the correct order. + tupleBucket := ct.bucket(conn.original.tupleID) + replyBucket := ct.bucket(conn.reply.tupleID) + ct.mu.RLock() + defer ct.mu.RUnlock() + if tupleBucket < replyBucket { + ct.buckets[tupleBucket].mu.Lock() + ct.buckets[replyBucket].mu.Lock() + } else if tupleBucket > replyBucket { + ct.buckets[replyBucket].mu.Lock() + ct.buckets[tupleBucket].mu.Lock() + } else { + // Both tuples are in the same bucket. + ct.buckets[tupleBucket].mu.Lock() } - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) - // TODO(gvisor.dev/issue/170): Support only redirect of ports. Need to - // support changing of address for Prerouting. - - // Change the port as per the iptables rule. This tuple will be used - // to manipulate the packet in HandlePacket. - conn.replyTupleHolder.tuple.src.addr = rt.MinIP - conn.replyTupleHolder.tuple.src.port = rt.MinPort - newHash := ct.getTupleHash(conn.replyTupleHolder.tuple) + // Now that we hold the locks, ensure the tuple hasn't been inserted by + // another thread. + alreadyInserted := false + for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() { + if other.tupleID == conn.original.tupleID { + alreadyInserted = true + break + } + } - // Add the changed tuple to the map. - ct.connMu.Lock() - defer ct.connMu.Unlock() - ct.CtMap[newHash] = conn.replyTupleHolder - if hook == Output { - conn.replyTupleHolder.conn.manip = manipDstOutput + if !alreadyInserted { + // Add the tuple to the map. + ct.buckets[tupleBucket].tuples.PushFront(&conn.original) + ct.buckets[replyBucket].tuples.PushFront(&conn.reply) } - // Delete the old tuple. - delete(ct.CtMap, replyHash) + // Unlocking can happen in any order. + ct.buckets[tupleBucket].mu.Unlock() + if tupleBucket != replyBucket { + ct.buckets[replyBucket].mu.Unlock() + } } // handlePacketPrerouting manipulates ports for packets in Prerouting hook. -// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.. -func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) { +// TODO(gvisor.dev/issue/170): Change address for Prerouting hook. +func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -308,21 +349,31 @@ func handlePacketPrerouting(pkt *PacketBuffer, conn *connTrack, dir ctDirection) // modified. switch dir { case dirOriginal: - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) case dirReply: - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } + // TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated + // on inbound packets, so we don't recalculate them. However, we should + // support cases when they are validated, e.g. when we can't offload + // receive checksumming. + netHeader.SetChecksum(0) netHeader.SetChecksum(^netHeader.CalculateChecksum()) } // handlePacketOutput manipulates ports for packets in Output hook. -func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, dir ctDirection) { +func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) { + // If this is a noop entry, don't do anything. + if conn.manip == manipNone { + return + } + netHeader := header.IPv4(pkt.NetworkHeader) tcpHeader := header.TCP(pkt.TransportHeader) @@ -331,13 +382,13 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, // modified. For prerouting redirection, we only reach this point // when replying, so packet sources are modified. if conn.manip == manipDstOutput && dir == dirOriginal { - port := conn.replyTupleHolder.tuple.src.port + port := conn.reply.srcPort tcpHeader.SetDestinationPort(port) - netHeader.SetDestinationAddress(conn.replyTupleHolder.tuple.src.addr) + netHeader.SetDestinationAddress(conn.reply.srcAddr) } else { - port := conn.originalTupleHolder.tuple.dst.port + port := conn.original.dstPort tcpHeader.SetSourcePort(port) - netHeader.SetSourceAddress(conn.originalTupleHolder.tuple.dst.addr) + netHeader.SetSourceAddress(conn.original.dstAddr) } // Calculate the TCP checksum and set it. @@ -356,33 +407,32 @@ func handlePacketOutput(pkt *PacketBuffer, conn *connTrack, gso *GSO, r *Route, netHeader.SetChecksum(^netHeader.CalculateChecksum()) } -// HandlePacket will manipulate the port and address of the packet if the -// connection exists. -func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) { +// handlePacket will manipulate the port and address of the packet if the +// connection exists. Returns whether, after the packet traverses the tables, +// it should create a new entry in the table. +func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Route) bool { if pkt.NatDone { - return + return false } if hook != Prerouting && hook != Output { - return + return false } - conn, dir := ct.connTrackForPacket(pkt, hook, false) - // Connection or Rule not found for the packet. - if conn == nil { - return + // TODO(gvisor.dev/issue/170): Support other transport protocols. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return false } - netHeader := header.IPv4(pkt.NetworkHeader) - // TODO(gvisor.dev/issue/170): Need to support for other transport - // protocols as well. - if netHeader == nil || netHeader.TransportProtocol() != header.TCPProtocolNumber { - return + conn, dir := ct.connFor(pkt) + // Connection or Rule not found for the packet. + if conn == nil { + return true } tcpHeader := header.TCP(pkt.TransportHeader) if tcpHeader == nil { - return + return false } switch hook { @@ -396,39 +446,161 @@ func (ct *ConnTrackTable) HandlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r // Update the state of tcb. // TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle // other tcp states. - var st tcpconntrack.Result - if conn.tcb.IsEmpty() { - conn.tcb.Init(tcpHeader) - conn.tcbHook = hook - } else { - switch hook { - case conn.tcbHook: - st = conn.tcb.UpdateStateOutbound(tcpHeader) - default: - st = conn.tcb.UpdateStateInbound(tcpHeader) - } + conn.mu.Lock() + defer conn.mu.Unlock() + + // Mark the connection as having been used recently so it isn't reaped. + conn.lastUsed = time.Now() + // Update connection state. + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + + return false +} + +// maybeInsertNoop tries to insert a no-op connection entry to keep connections +// from getting clobbered when replies arrive. It only inserts if there isn't +// already a connection for pkt. +// +// This should be called after traversing iptables rules only, to ensure that +// pkt.NatDone is set correctly. +func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { + // If there were a rule applying to this packet, it would be marked + // with NatDone. + if pkt.NatDone { + return } - // Delete conntrack if tcp connection is closed. - if st == tcpconntrack.ResultClosedByPeer || st == tcpconntrack.ResultClosedBySelf || st == tcpconntrack.ResultReset { - ct.deleteConnTrack(conn) + // We only track TCP connections. + if pkt.NetworkHeader == nil || header.IPv4(pkt.NetworkHeader).TransportProtocol() != header.TCPProtocolNumber { + return } -} -// deleteConnTrack deletes the connection. -func (ct *ConnTrackTable) deleteConnTrack(conn *connTrack) { - if conn == nil { + // This is the first packet we're seeing for the TCP connection. Insert + // the noop entry (an identity mapping) so that the response doesn't + // get NATed, breaking the connection. + tid, err := packetToTupleID(pkt) + if err != nil { return } + conn := newConn(tid, tid.reply(), manipNone, hook) + conn.updateLocked(header.TCP(pkt.TransportHeader), hook) + ct.insertConn(conn) +} - tuple := conn.originalTupleHolder.tuple - hash := ct.getTupleHash(tuple) - replyTuple := conn.replyTupleHolder.tuple - replyHash := ct.getTupleHash(replyTuple) +// bucket gets the conntrack bucket for a tupleID. +func (ct *ConnTrack) bucket(id tupleID) int { + h := jenkins.Sum32(ct.seed) + h.Write([]byte(id.srcAddr)) + h.Write([]byte(id.dstAddr)) + shortBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(shortBuf, id.srcPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, id.dstPort) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto)) + h.Write([]byte(shortBuf)) + binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto)) + h.Write([]byte(shortBuf)) + ct.mu.RLock() + defer ct.mu.RUnlock() + return int(h.Sum32()) % len(ct.buckets) +} - ct.connMu.Lock() - defer ct.connMu.Unlock() +// reapUnused deletes timed out entries from the conntrack map. The rules for +// reaping are: +// - Most reaping occurs in connFor, which is called on each packet. connFor +// cleans up the bucket the packet's connection maps to. Thus calls to +// reapUnused should be fast. +// - Each call to reapUnused traverses a fraction of the conntrack table. +// Specifically, it traverses len(ct.buckets)/fractionPerReaping. +// - After reaping, reapUnused decides when it should next run based on the +// ratio of expired connections to examined connections. If the ratio is +// greater than maxExpiredPct, it schedules the next run quickly. Otherwise it +// slightly increases the interval between runs. +// - maxFullTraversal caps the time it takes to traverse the entire table. +// +// reapUnused returns the next bucket that should be checked and the time after +// which it should be called again. +func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) { + // TODO(gvisor.dev/issue/170): This can be more finely controlled, as + // it is in Linux via sysctl. + const fractionPerReaping = 128 + const maxExpiredPct = 50 + const maxFullTraversal = 60 * time.Second + const minInterval = 10 * time.Millisecond + const maxInterval = maxFullTraversal / fractionPerReaping + + now := time.Now() + checked := 0 + expired := 0 + var idx int + ct.mu.RLock() + defer ct.mu.RUnlock() + for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ { + idx = (i + start) % len(ct.buckets) + ct.buckets[idx].mu.Lock() + for tuple := ct.buckets[idx].tuples.Front(); tuple != nil; tuple = tuple.Next() { + checked++ + if ct.reapTupleLocked(tuple, idx, now) { + expired++ + } + } + ct.buckets[idx].mu.Unlock() + } + // We already checked buckets[idx]. + idx++ + + // If half or more of the connections are expired, the table has gotten + // stale. Reschedule quickly. + expiredPct := 0 + if checked != 0 { + expiredPct = expired * 100 / checked + } + if expiredPct > maxExpiredPct { + return idx, minInterval + } + if interval := prevInterval + minInterval; interval <= maxInterval { + // Increment the interval between runs. + return idx, interval + } + // We've hit the maximum interval. + return idx, maxInterval +} + +// reapTupleLocked tries to remove tuple and its reply from the table. It +// returns whether the tuple's connection has timed out. +// +// Preconditions: ct.mu is locked for reading and bucket is locked. +func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { + if !tuple.conn.timedOut(now) { + return false + } + + // To maintain lock order, we can only reap these tuples if the reply + // appears later in the table. + replyBucket := ct.bucket(tuple.reply()) + if bucket > replyBucket { + return true + } + + // Don't re-lock if both tuples are in the same bucket. + differentBuckets := bucket != replyBucket + if differentBuckets { + ct.buckets[replyBucket].mu.Lock() + } + + // We have the buckets locked and can remove both tuples. + if tuple.direction == dirOriginal { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.reply) + } else { + ct.buckets[replyBucket].tuples.Remove(&tuple.conn.original) + } + ct.buckets[bucket].tuples.Remove(tuple) + + // Don't re-unlock if both tuples are in the same bucket. + if differentBuckets { + ct.buckets[replyBucket].mu.Unlock() + } - delete(ct.CtMap, hash) - delete(ct.CtMap, replyHash) + return true } diff --git a/pkg/tcpip/stack/fake_time_test.go b/pkg/tcpip/stack/fake_time_test.go new file mode 100644 index 000000000..92c8cb534 --- /dev/null +++ b/pkg/tcpip/stack/fake_time_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "container/heap" + "sync" + "time" + + "github.com/dpjacques/clockwork" + "gvisor.dev/gvisor/pkg/tcpip" +) + +type fakeClock struct { + clock clockwork.FakeClock + + // mu protects the fields below. + mu sync.RWMutex + + // times is min-heap of times. A heap is used for quick retrieval of the next + // upcoming time of scheduled work. + times *timeHeap + + // waitGroups stores one WaitGroup for all work scheduled to execute at the + // same time via AfterFunc. This allows parallel execution of all functions + // passed to AfterFunc scheduled for the same time. + waitGroups map[time.Time]*sync.WaitGroup +} + +func newFakeClock() *fakeClock { + return &fakeClock{ + clock: clockwork.NewFakeClock(), + times: &timeHeap{}, + waitGroups: make(map[time.Time]*sync.WaitGroup), + } +} + +var _ tcpip.Clock = (*fakeClock)(nil) + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (fc *fakeClock) NowNanoseconds() int64 { + return fc.clock.Now().UnixNano() +} + +// NowMonotonic implements tcpip.Clock.NowMonotonic. +func (fc *fakeClock) NowMonotonic() int64 { + return fc.NowNanoseconds() +} + +// AfterFunc implements tcpip.Clock.AfterFunc. +func (fc *fakeClock) AfterFunc(d time.Duration, f func()) tcpip.Timer { + until := fc.clock.Now().Add(d) + wg := fc.addWait(until) + return &fakeTimer{ + clock: fc, + until: until, + timer: fc.clock.AfterFunc(d, func() { + defer wg.Done() + f() + }), + } +} + +// addWait adds an additional wait to the WaitGroup for parallel execution of +// all work scheduled for t. Returns a reference to the WaitGroup modified. +func (fc *fakeClock) addWait(t time.Time) *sync.WaitGroup { + fc.mu.RLock() + wg, ok := fc.waitGroups[t] + fc.mu.RUnlock() + + if ok { + wg.Add(1) + return wg + } + + fc.mu.Lock() + heap.Push(fc.times, t) + fc.mu.Unlock() + + wg = &sync.WaitGroup{} + wg.Add(1) + + fc.mu.Lock() + fc.waitGroups[t] = wg + fc.mu.Unlock() + + return wg +} + +// removeWait removes a wait from the WaitGroup for parallel execution of all +// work scheduled for t. +func (fc *fakeClock) removeWait(t time.Time) { + fc.mu.RLock() + defer fc.mu.RUnlock() + + wg := fc.waitGroups[t] + wg.Done() +} + +// advance executes all work that have been scheduled to execute within d from +// the current fake time. Blocks until all work has completed execution. +func (fc *fakeClock) advance(d time.Duration) { + // Block until all the work is done + until := fc.clock.Now().Add(d) + for { + fc.mu.Lock() + if fc.times.Len() == 0 { + fc.mu.Unlock() + return + } + + t := heap.Pop(fc.times).(time.Time) + if t.After(until) { + // No work to do + heap.Push(fc.times, t) + fc.mu.Unlock() + return + } + fc.mu.Unlock() + + diff := t.Sub(fc.clock.Now()) + fc.clock.Advance(diff) + + fc.mu.RLock() + wg := fc.waitGroups[t] + fc.mu.RUnlock() + + wg.Wait() + + fc.mu.Lock() + delete(fc.waitGroups, t) + fc.mu.Unlock() + } +} + +type fakeTimer struct { + clock *fakeClock + timer clockwork.Timer + + mu sync.RWMutex + until time.Time +} + +var _ tcpip.Timer = (*fakeTimer)(nil) + +// Reset implements tcpip.Timer.Reset. +func (ft *fakeTimer) Reset(d time.Duration) { + if !ft.timer.Reset(d) { + return + } + + ft.mu.Lock() + defer ft.mu.Unlock() + + ft.clock.removeWait(ft.until) + ft.until = ft.clock.clock.Now().Add(d) + ft.clock.addWait(ft.until) +} + +// Stop implements tcpip.Timer.Stop. +func (ft *fakeTimer) Stop() bool { + if !ft.timer.Stop() { + return false + } + + ft.mu.RLock() + defer ft.mu.RUnlock() + + ft.clock.removeWait(ft.until) + return true +} + +type timeHeap []time.Time + +var _ heap.Interface = (*timeHeap)(nil) + +func (h timeHeap) Len() int { + return len(h) +} + +func (h timeHeap) Less(i, j int) bool { + return h[i].Before(h[j]) +} + +func (h timeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *timeHeap) Push(x interface{}) { + *h = append(*h, x.(time.Time)) +} + +func (h *timeHeap) Pop() interface{} { + last := (*h)[len(*h)-1] + *h = (*h)[:len(*h)-1] + return last +} diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index a6546cef0..c962693f5 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" ) const ( @@ -120,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {} type fwdTestNetworkProtocol struct { addrCache *linkAddrCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address) + onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) } +var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) + func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return fwdTestNetNumber } @@ -173,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {} func (f *fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { if f.addrCache != nil && f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr) + f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) }) } return nil @@ -301,6 +304,16 @@ func (e *fwdTestLinkEndpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Er // Wait implements stack.LinkEndpoint.Wait. func (*fwdTestLinkEndpoint) Wait() {} +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*fwdTestLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ @@ -394,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any address will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -452,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Only packets to address 3 will be resolved to the // link address "c". if addr == "\x03" { @@ -504,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -548,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, @@ -605,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { // Create a network protocol with a fake resolver. proto := &fwdTestNetworkProtocol{ addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) { + onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { // Any packets will be resolved to the link address "c". cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") }, diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 62d4eb1b6..cbbae4224 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -16,39 +16,49 @@ package stack import ( "fmt" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) -// Table names. +// tableID is an index into IPTables.tables. +type tableID int + const ( - TablenameNat = "nat" - TablenameMangle = "mangle" - TablenameFilter = "filter" + natID tableID = iota + mangleID + filterID + numTables ) -// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +// Table names. const ( - ChainNamePrerouting = "PREROUTING" - ChainNameInput = "INPUT" - ChainNameForward = "FORWARD" - ChainNameOutput = "OUTPUT" - ChainNamePostrouting = "POSTROUTING" + NATTable = "nat" + MangleTable = "mangle" + FilterTable = "filter" ) +// nameToID is immutable. +var nameToID = map[string]tableID{ + NATTable: natID, + MangleTable: mangleID, + FilterTable: filterID, +} + // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 +// reaperDelay is how long to wait before starting to reap connections. +const reaperDelay = 5 * time.Second + // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { - // TODO(gvisor.dev/issue/170): We may be able to swap out some strings for - // iotas. return &IPTables{ - tables: map[string]Table{ - TablenameNat: Table{ + tables: [numTables]Table{ + natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, @@ -56,65 +66,71 @@ func DefaultTables() *IPTables { Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - Underflows: map[Hook]int{ + Underflows: [NumHooks]int{ Prerouting: 0, Input: 1, + Forward: HookUnset, Output: 2, Postrouting: 3, }, - UserChains: map[string]int{}, }, - TablenameMangle: Table{ + mangleID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ + BuiltinChains: [NumHooks]int{ Prerouting: 0, Output: 1, }, - Underflows: map[Hook]int{ - Prerouting: 0, - Output: 1, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, - TablenameFilter: Table{ + filterID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: AcceptTarget{}}, Rule{Target: ErrorTarget{}}, }, - BuiltinChains: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: 0, - Forward: 1, - Output: 2, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, }, }, - priorities: map[Hook][]string{ - Input: []string{TablenameNat, TablenameFilter}, - Prerouting: []string{TablenameMangle, TablenameNat}, - Output: []string{TablenameMangle, TablenameNat, TablenameFilter}, + priorities: [NumHooks][]tableID{ + Prerouting: []tableID{mangleID, natID}, + Input: []tableID{natID, filterID}, + Output: []tableID{mangleID, natID, filterID}, }, - connections: ConnTrackTable{ - CtMap: make(map[uint32]ConnTrackTupleHolder), - Seed: generateRandUint32(), + connections: ConnTrack{ + seed: generateRandUint32(), }, + reaperDone: make(chan struct{}, 1), } } @@ -123,62 +139,59 @@ func DefaultTables() *IPTables { func EmptyFilterTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - Underflows: map[Hook]int{ - Input: HookUnset, - Forward: HookUnset, - Output: HookUnset, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Postrouting: HookUnset, }, - UserChains: map[string]int{}, } } -// EmptyNatTable returns a Table with no rules and the filter table chains +// EmptyNATTable returns a Table with no rules and the filter table chains // mapped to HookUnset. -func EmptyNatTable() Table { +func EmptyNATTable() Table { return Table{ Rules: []Rule{}, - BuiltinChains: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + BuiltinChains: [NumHooks]int{ + Forward: HookUnset, }, - Underflows: map[Hook]int{ - Prerouting: HookUnset, - Input: HookUnset, - Output: HookUnset, - Postrouting: HookUnset, + Underflows: [NumHooks]int{ + Forward: HookUnset, }, - UserChains: map[string]int{}, } } -// GetTable returns table by name. +// GetTable returns a table by name. func (it *IPTables) GetTable(name string) (Table, bool) { + id, ok := nameToID[name] + if !ok { + return Table{}, false + } it.mu.RLock() defer it.mu.RUnlock() - t, ok := it.tables[name] - return t, ok + return it.tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) { +func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { + id, ok := nameToID[name] + if !ok { + return tcpip.ErrInvalidOptionValue + } it.mu.Lock() defer it.mu.Unlock() + // If iptables is being enabled, initialize the conntrack table and + // reaper. + if !it.modified { + it.connections.buckets = make([]bucket, numBuckets) + it.startReaper(reaperDelay) + } it.modified = true - it.tables[name] = table -} - -// GetPriorities returns slice of priorities associated with hook. -func (it *IPTables) GetPriorities(hook Hook) []string { - it.mu.RLock() - defer it.mu.RUnlock() - return it.priorities[hook] + it.tables[id] = table + return nil } // A chainVerdict is what a table decides should be done with a packet. @@ -213,11 +226,19 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Packets are manipulated only if connection and matching // NAT rule exists. - it.connections.HandlePacket(pkt, hook, gso, r) + shouldTrack := it.connections.handlePacket(pkt, hook, gso, r) // Go through each table containing the hook. - for _, tablename := range it.GetPriorities(hook) { - table, _ := it.GetTable(tablename) + it.mu.RLock() + defer it.mu.RUnlock() + priorities := it.priorities[hook] + for _, tableID := range priorities { + // If handlePacket already NATed the packet, we don't need to + // check the NAT table. + if tableID == natID && pkt.NatDone { + continue + } + table := it.tables[tableID] ruleIdx := table.BuiltinChains[hook] switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { // If the table returns Accept, move on to the next table. @@ -246,17 +267,59 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr } } + // If this connection should be tracked, try to add an entry for it. If + // traversing the nat table didn't end in adding an entry, + // maybeInsertNoop will add a no-op entry for the connection. This is + // needeed when establishing connections so that the SYN/ACK reply to an + // outgoing SYN is delivered to the correct endpoint rather than being + // redirected by a prerouting rule. + // + // From the iptables documentation: "If there is no rule, a `null' + // binding is created: this usually does not map the packet, but exists + // to ensure we don't map another stream over an existing one." + if shouldTrack { + it.connections.maybeInsertNoop(pkt, hook) + } + // Every table returned Accept. return true } +// beforeSave is invoked by stateify. +func (it *IPTables) beforeSave() { + // Ensure the reaper exits cleanly. + it.reaperDone <- struct{}{} + // Prevent others from modifying the connection table. + it.connections.mu.Lock() +} + +// afterLoad is invoked by stateify. +func (it *IPTables) afterLoad() { + it.startReaper(reaperDelay) +} + +// startReaper starts a goroutine that wakes up periodically to reap timed out +// connections. +func (it *IPTables) startReaper(interval time.Duration) { + go func() { // S/R-SAFE: reaperDone is signalled when iptables is saved. + bucket := 0 + for { + select { + case <-it.reaperDone: + return + case <-time.After(interval): + bucket, interval = it.connections.reapUnused(bucket, interval) + } + } + }() +} + // CheckPackets runs pkts through the rules for hook and returns a map of packets that // should not go forward. // -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// -// TODO(gvisor.dev/issue/170): pk.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -280,9 +343,9 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * return drop, natPkts } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. @@ -327,23 +390,12 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId return chainDrop } -// Precondition: pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// TODO(gvisor.dev/issue/170): pkt.NetworkHeader will always be set as a -// precondition. +// Preconditions: +// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// - pkt.NetworkHeader is not nil. func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] - // If pkt.NetworkHeader hasn't been set yet, it will be contained in - // pkt.Data. - if pkt.NetworkHeader == nil { - var ok bool - pkt.NetworkHeader, ok = pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { - // Precondition has been violated. - panic(fmt.Sprintf("iptables checks require IPv4 headers of at least %d bytes", header.IPv4MinimumSize)) - } - } - // Check whether the packet matches the IP header filter. if !rule.Filter.match(header.IPv4(pkt.NetworkHeader), hook, nicName) { // Continue on to the next rule. diff --git a/pkg/tcpip/stack/iptables_state.go b/pkg/tcpip/stack/iptables_state.go new file mode 100644 index 000000000..529e02a07 --- /dev/null +++ b/pkg/tcpip/stack/iptables_state.go @@ -0,0 +1,40 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastUsed is invoked by stateify. +func (cn *conn) saveLastUsed() unixTime { + return unixTime{cn.lastUsed.Unix(), cn.lastUsed.UnixNano()} +} + +// loadLastUsed is invoked by stateify. +func (cn *conn) loadLastUsed(unix unixTime) { + cn.lastUsed = time.Unix(unix.second, unix.nano) +} + +// beforeSave is invoked by stateify. +func (ct *ConnTrack) beforeSave() { + ct.mu.Lock() +} diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 92e31643e..dc88033c7 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -24,7 +24,7 @@ import ( type AcceptTarget struct{} // Action implements Target.Action. -func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 } @@ -32,7 +32,7 @@ func (AcceptTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, t type DropTarget struct{} // Action implements Target.Action. -func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } @@ -41,7 +41,7 @@ func (DropTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcp type ErrorTarget struct{} // Action implements Target.Action. -func (ErrorTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -52,7 +52,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -61,7 +61,7 @@ func (UserChainTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route type ReturnTarget struct{} // Action implements Target.Action. -func (ReturnTarget) Action(*PacketBuffer, *ConnTrackTable, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { +func (ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } @@ -92,7 +92,7 @@ type RedirectTarget struct { // TODO(gvisor.dev/issue/170): Parse headers without copying. The current // implementation only works for PREROUTING and calls pkt.Clone(), neither // of which should be the case. -func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -150,12 +150,11 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook return RuleAccept, 0 } - // Set up conection for matching NAT rule. - // Only the first packet of the connection comes here. - // Other packets will be manipulated in connection tracking. - if conn, _ := ct.connTrackForPacket(pkt, hook, true); conn != nil { - ct.SetNatInfo(pkt, rt, hook) - ct.HandlePacket(pkt, hook, gso, r) + // Set up conection for matching NAT rule. Only the first + // packet of the connection comes here. Other packets will be + // manipulated in connection tracking. + if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + ct.handlePacket(pkt, hook, gso, r) } default: return RuleDrop, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7026990c4..73274ada9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -78,50 +78,55 @@ const ( ) // IPTables holds all the tables for a netstack. +// +// +stateify savable type IPTables struct { // mu protects tables, priorities, and modified. mu sync.RWMutex - // tables maps table names to tables. User tables have arbitrary names. - // mu needs to be locked for accessing. - tables map[string]Table + // tables maps tableIDs to tables. Holds builtin tables only, not user + // tables. mu must be locked for accessing. + tables [numTables]Table // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. mu needs to be locked for accessing. - priorities map[Hook][]string + priorities [NumHooks][]tableID // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool - connections ConnTrackTable + connections ConnTrack + + // reaperDone can be signalled to stop the reaper goroutine. + reaperDone chan struct{} } // A Table defines a set of chains and hooks into the network stack. It is // really just a list of rules. +// +// +stateify savable type Table struct { // Rules holds the rules that make up the table. Rules []Rule // BuiltinChains maps builtin chains to their entrypoint rule in Rules. - BuiltinChains map[Hook]int + BuiltinChains [NumHooks]int // Underflows maps builtin chains to their underflow rule in Rules // (i.e. the rule to execute if the chain returns without a verdict). - Underflows map[Hook]int - - // UserChains holds user-defined chains for the keyed by name. Users - // can give their chains arbitrary names. - UserChains map[string]int + Underflows [NumHooks]int } // ValidHooks returns a bitmap of the builtin hooks for the given table. func (table *Table) ValidHooks() uint32 { hooks := uint32(0) - for hook := range table.BuiltinChains { - hooks |= 1 << hook + for hook, ruleIdx := range table.BuiltinChains { + if ruleIdx != HookUnset { + hooks |= 1 << hook + } } return hooks } @@ -130,6 +135,8 @@ func (table *Table) ValidHooks() uint32 { // contains zero or more matchers, each of which is a specification of which // packets this rule applies to. If there are no matchers in the rule, it // applies to any packet. +// +// +stateify savable type Rule struct { // Filter holds basic IP filtering fields common to every rule. Filter IPHeaderFilter @@ -142,6 +149,8 @@ type Rule struct { } // IPHeaderFilter holds basic IP filtering data common to every rule. +// +// +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber @@ -249,5 +258,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(packet *PacketBuffer, connections *ConnTrackTable, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) + Action(packet *PacketBuffer, connections *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 403557fd7..6f73a0ce4 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 1baa498d0..b15b8d1cb 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -48,7 +48,7 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) if f := r.onLinkAddressRequest; f != nil { f() diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go index e28c23d66..5174e639c 100644 --- a/pkg/tcpip/stack/ndp.go +++ b/pkg/tcpip/stack/ndp.go @@ -33,12 +33,6 @@ const ( // Default = 1 (from RFC 4862 section 5.1) defaultDupAddrDetectTransmits = 1 - // defaultRetransmitTimer is the default amount of time to wait between - // sending NDP Neighbor solicitation messages. - // - // Default = 1s (from RFC 4861 section 10). - defaultRetransmitTimer = time.Second - // defaultMaxRtrSolicitations is the default number of Router // Solicitation messages to send when a NIC becomes enabled. // @@ -79,16 +73,6 @@ const ( // Default = true. defaultAutoGenGlobalAddresses = true - // minimumRetransmitTimer is the minimum amount of time to wait between - // sending NDP Neighbor solicitation messages. Note, RFC 4861 does - // not impose a minimum Retransmit Timer, but we do here to make sure - // the messages are not sent all at once. We also come to this value - // because in the RetransmitTimer field of a Router Advertisement, a - // value of 0 means unspecified, so the smallest valid value is 1. - // Note, the unit of the RetransmitTimer field in the Router - // Advertisement is milliseconds. - minimumRetransmitTimer = time.Millisecond - // minimumRtrSolicitationInterval is the minimum amount of time to wait // between sending Router Solicitation messages. This limit is imposed // to make sure that Router Solicitation messages are not sent all at @@ -469,7 +453,7 @@ type ndpState struct { rtrSolicit struct { // The timer used to send the next router solicitation message. - timer *time.Timer + timer tcpip.Timer // Used to let the Router Solicitation timer know that it has been stopped. // @@ -503,7 +487,7 @@ type ndpState struct { // to the DAD goroutine that DAD should stop. type dadState struct { // The DAD timer to send the next NS message, or resolve the address. - timer *time.Timer + timer tcpip.Timer // Used to let the DAD timer know that it has been stopped. // @@ -515,38 +499,38 @@ type dadState struct { // defaultRouterState holds data associated with a default router discovered by // a Router Advertisement (RA). type defaultRouterState struct { - // Timer to invalidate the default router. + // Job to invalidate the default router. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // onLinkPrefixState holds data associated with an on-link prefix discovered by // a Router Advertisement's Prefix Information option (PI) when the NDP // configurations was configured to do so. type onLinkPrefixState struct { - // Timer to invalidate the on-link prefix. + // Job to invalidate the on-link prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job } // tempSLAACAddrState holds state associated with a temporary SLAAC address. type tempSLAACAddrState struct { - // Timer to deprecate the temporary SLAAC address. + // Job to deprecate the temporary SLAAC address. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the temporary SLAAC address. + // Job to invalidate the temporary SLAAC address. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job - // Timer to regenerate the temporary SLAAC address. + // Job to regenerate the temporary SLAAC address. // // Must not be nil. - regenTimer *tcpip.CancellableTimer + regenJob *tcpip.Job createdAt time.Time @@ -561,15 +545,15 @@ type tempSLAACAddrState struct { // slaacPrefixState holds state associated with a SLAAC prefix. type slaacPrefixState struct { - // Timer to deprecate the prefix. + // Job to deprecate the prefix. // // Must not be nil. - deprecationTimer *tcpip.CancellableTimer + deprecationJob *tcpip.Job - // Timer to invalidate the prefix. + // Job to invalidate the prefix. // // Must not be nil. - invalidationTimer *tcpip.CancellableTimer + invalidationJob *tcpip.Job // Nonzero only when the address is not valid forever. validUntil time.Time @@ -651,12 +635,12 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, ref *ref } var done bool - var timer *time.Timer + var timer tcpip.Timer // We initially start a timer to fire immediately because some of the DAD work // cannot be done while holding the NIC's lock. This is effectively the same // as starting a goroutine but we use a timer that fires immediately so we can // reset it for the next DAD iteration. - timer = time.AfterFunc(0, func() { + timer = ndp.nic.stack.Clock().AfterFunc(0, func() { ndp.nic.mu.Lock() defer ndp.nic.mu.Unlock() @@ -871,9 +855,9 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { case ok && rl != 0: // This is an already discovered default router. Update - // the invalidation timer. - rtr.invalidationTimer.StopLocked() - rtr.invalidationTimer.Reset(rl) + // the invalidation job. + rtr.invalidationJob.Cancel() + rtr.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = rtr case ok && rl == 0: @@ -950,7 +934,7 @@ func (ndp *ndpState) invalidateDefaultRouter(ip tcpip.Address) { return } - rtr.invalidationTimer.StopLocked() + rtr.invalidationJob.Cancel() delete(ndp.defaultRouters, ip) // Let the integrator know a discovered default router is invalidated. @@ -979,12 +963,12 @@ func (ndp *ndpState) rememberDefaultRouter(ip tcpip.Address, rl time.Duration) { } state := defaultRouterState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateDefaultRouter(ip) }), } - state.invalidationTimer.Reset(rl) + state.invalidationJob.Schedule(rl) ndp.defaultRouters[ip] = state } @@ -1009,13 +993,13 @@ func (ndp *ndpState) rememberOnLinkPrefix(prefix tcpip.Subnet, l time.Duration) } state := onLinkPrefixState{ - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { ndp.invalidateOnLinkPrefix(prefix) }), } if l < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(l) + state.invalidationJob.Schedule(l) } ndp.onLinkPrefixes[prefix] = state @@ -1033,7 +1017,7 @@ func (ndp *ndpState) invalidateOnLinkPrefix(prefix tcpip.Subnet) { return } - s.invalidationTimer.StopLocked() + s.invalidationJob.Cancel() delete(ndp.onLinkPrefixes, prefix) // Let the integrator know a discovered on-link prefix is invalidated. @@ -1082,14 +1066,14 @@ func (ndp *ndpState) handleOnLinkPrefixInformation(pi header.NDPPrefixInformatio // This is an already discovered on-link prefix with a // new non-zero valid lifetime. // - // Update the invalidation timer. + // Update the invalidation job. - prefixState.invalidationTimer.StopLocked() + prefixState.invalidationJob.Cancel() if vl < header.NDPInfiniteLifetime { - // Prefix is valid for a finite lifetime, reset the timer to expire after + // Prefix is valid for a finite lifetime, schedule the job to execute after // the new valid lifetime. - prefixState.invalidationTimer.Reset(vl) + prefixState.invalidationJob.Schedule(vl) } ndp.onLinkPrefixes[prefix] = prefixState @@ -1154,7 +1138,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { } state := slaacPrefixState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the deprecated SLAAC prefix %s", prefix)) @@ -1162,7 +1146,7 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { ndp.deprecateSLAACAddress(state.stableAddr.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { state, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for the invalidated SLAAC prefix %s", prefix)) @@ -1184,19 +1168,19 @@ func (ndp *ndpState) doSLAAC(prefix tcpip.Subnet, pl, vl time.Duration) { if !ndp.generateSLAACAddr(prefix, &state) { // We were unable to generate an address for the prefix, we do not nothing - // further as there is no reason to maintain state or timers for a prefix we + // further as there is no reason to maintain state or jobs for a prefix we // do not have an address for. return } - // Setup the initial timers to deprecate and invalidate prefix. + // Setup the initial jobs to deprecate and invalidate prefix. if pl < header.NDPInfiniteLifetime && pl != 0 { - state.deprecationTimer.Reset(pl) + state.deprecationJob.Schedule(pl) } if vl < header.NDPInfiniteLifetime { - state.invalidationTimer.Reset(vl) + state.invalidationJob.Schedule(vl) state.validUntil = now.Add(vl) } @@ -1428,7 +1412,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla } state := tempSLAACAddrState{ - deprecationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + deprecationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to deprecate temporary address %s", prefix, generatedAddr)) @@ -1441,7 +1425,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.deprecateSLAACAddress(tempAddrState.ref) }), - invalidationTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + invalidationJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to invalidate temporary address %s", prefix, generatedAddr)) @@ -1454,7 +1438,7 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, generatedAddr.Address, tempAddrState) }), - regenTimer: tcpip.NewCancellableTimer(&ndp.nic.mu, func() { + regenJob: ndp.nic.stack.newJob(&ndp.nic.mu, func() { prefixState, ok := ndp.slaacPrefixes[prefix] if !ok { panic(fmt.Sprintf("ndp: must have a slaacPrefixes entry for %s to regenerate temporary address after %s", prefix, generatedAddr)) @@ -1481,9 +1465,9 @@ func (ndp *ndpState) generateTempSLAACAddr(prefix tcpip.Subnet, prefixState *sla ref: ref, } - state.deprecationTimer.Reset(pl) - state.invalidationTimer.Reset(vl) - state.regenTimer.Reset(pl - ndp.configs.RegenAdvanceDuration) + state.deprecationJob.Schedule(pl) + state.invalidationJob.Schedule(vl) + state.regenJob.Schedule(pl - ndp.configs.RegenAdvanceDuration) prefixState.generationAttempts++ prefixState.tempAddrs[generatedAddr.Address] = state @@ -1518,16 +1502,16 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat prefixState.stableAddr.ref.deprecated = false } - // If prefix was preferred for some finite lifetime before, stop the - // deprecation timer so it can be reset. - prefixState.deprecationTimer.StopLocked() + // If prefix was preferred for some finite lifetime before, cancel the + // deprecation job so it can be reset. + prefixState.deprecationJob.Cancel() now := time.Now() - // Reset the deprecation timer if prefix has a finite preferred lifetime. + // Schedule the deprecation job if prefix has a finite preferred lifetime. if pl < header.NDPInfiniteLifetime { if !deprecated { - prefixState.deprecationTimer.Reset(pl) + prefixState.deprecationJob.Schedule(pl) } prefixState.preferredUntil = now.Add(pl) } else { @@ -1546,9 +1530,9 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // 3) Otherwise, reset the valid lifetime of the prefix to 2 hours. if vl >= header.NDPInfiniteLifetime { - // Handle the infinite valid lifetime separately as we do not keep a timer - // in this case. - prefixState.invalidationTimer.StopLocked() + // Handle the infinite valid lifetime separately as we do not schedule a + // job in this case. + prefixState.invalidationJob.Cancel() prefixState.validUntil = time.Time{} } else { var effectiveVl time.Duration @@ -1569,8 +1553,8 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } if effectiveVl != 0 { - prefixState.invalidationTimer.StopLocked() - prefixState.invalidationTimer.Reset(effectiveVl) + prefixState.invalidationJob.Cancel() + prefixState.invalidationJob.Schedule(effectiveVl) prefixState.validUntil = now.Add(effectiveVl) } } @@ -1582,7 +1566,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // Note, we do not need to update the entries in the temporary address map - // after updating the timers because the timers are held as pointers. + // after updating the jobs because the jobs are held as pointers. var regenForAddr tcpip.Address allAddressesRegenerated := true for tempAddr, tempAddrState := range prefixState.tempAddrs { @@ -1596,14 +1580,14 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer valid, invalidate it immediately. Otherwise, - // reset the invalidation timer. + // reset the invalidation job. newValidLifetime := validUntil.Sub(now) if newValidLifetime <= 0 { ndp.invalidateTempSLAACAddr(prefixState.tempAddrs, tempAddr, tempAddrState) continue } - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.invalidationTimer.Reset(newValidLifetime) + tempAddrState.invalidationJob.Cancel() + tempAddrState.invalidationJob.Schedule(newValidLifetime) // As per RFC 4941 section 3.3 step 4, the preferred lifetime of a temporary // address is the lower of the preferred lifetime of the stable address or @@ -1616,17 +1600,17 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat } // If the address is no longer preferred, deprecate it immediately. - // Otherwise, reset the deprecation timer. + // Otherwise, schedule the deprecation job again. newPreferredLifetime := preferredUntil.Sub(now) - tempAddrState.deprecationTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() if newPreferredLifetime <= 0 { ndp.deprecateSLAACAddress(tempAddrState.ref) } else { tempAddrState.ref.deprecated = false - tempAddrState.deprecationTimer.Reset(newPreferredLifetime) + tempAddrState.deprecationJob.Schedule(newPreferredLifetime) } - tempAddrState.regenTimer.StopLocked() + tempAddrState.regenJob.Cancel() if tempAddrState.regenerated { } else { allAddressesRegenerated = false @@ -1637,7 +1621,7 @@ func (ndp *ndpState) refreshSLAACPrefixLifetimes(prefix tcpip.Subnet, prefixStat // immediately after we finish iterating over the temporary addresses. regenForAddr = tempAddr } else { - tempAddrState.regenTimer.Reset(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) + tempAddrState.regenJob.Schedule(newPreferredLifetime - ndp.configs.RegenAdvanceDuration) } } } @@ -1717,7 +1701,7 @@ func (ndp *ndpState) cleanupSLAACAddrResourcesAndNotify(addr tcpip.AddressWithPr ndp.cleanupSLAACPrefixResources(prefix, state) } -// cleanupSLAACPrefixResources cleansup a SLAAC prefix's timers and entry. +// cleanupSLAACPrefixResources cleans up a SLAAC prefix's jobs and entry. // // Panics if the SLAAC prefix is not known. // @@ -1729,8 +1713,8 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa } state.stableAddr.ref = nil - state.deprecationTimer.StopLocked() - state.invalidationTimer.StopLocked() + state.deprecationJob.Cancel() + state.invalidationJob.Cancel() delete(ndp.slaacPrefixes, prefix) } @@ -1775,13 +1759,13 @@ func (ndp *ndpState) cleanupTempSLAACAddrResourcesAndNotify(addr tcpip.AddressWi } // cleanupTempSLAACAddrResourcesAndNotify cleans up a temporary SLAAC address's -// timers and entry. +// jobs and entry. // // The NIC that ndp belongs to MUST be locked. func (ndp *ndpState) cleanupTempSLAACAddrResources(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { - tempAddrState.deprecationTimer.StopLocked() - tempAddrState.invalidationTimer.StopLocked() - tempAddrState.regenTimer.StopLocked() + tempAddrState.deprecationJob.Cancel() + tempAddrState.invalidationJob.Cancel() + tempAddrState.regenJob.Cancel() delete(tempAddrs, tempAddr) } @@ -1860,7 +1844,7 @@ func (ndp *ndpState) startSolicitingRouters() { var done bool ndp.rtrSolicit.done = &done - ndp.rtrSolicit.timer = time.AfterFunc(delay, func() { + ndp.rtrSolicit.timer = ndp.nic.stack.Clock().AfterFunc(delay, func() { ndp.nic.mu.Lock() if done { // If we reach this point, it means that the RS timer fired after another diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 6f86abc98..644ba7c33 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1254,7 +1254,7 @@ func TestRouterDiscovery(t *testing.T) { default: } - // Wait for lladdr2's router invalidation timer to fire. The lifetime + // Wait for lladdr2's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1271,7 +1271,7 @@ func TestRouterDiscovery(t *testing.T) { e.InjectInbound(header.IPv6ProtocolNumber, raBuf(llAddr2, 0)) expectRouterEvent(llAddr2, false) - // Wait for lladdr3's router invalidation timer to fire. The lifetime + // Wait for lladdr3's router invalidation job to execute. The lifetime // of the router should have been updated to the most recent (smaller) // lifetime. // @@ -1502,7 +1502,7 @@ func TestPrefixDiscovery(t *testing.T) { default: } - // Wait for prefix2's most recent invalidation timer plus some buffer to + // Wait for prefix2's most recent invalidation job plus some buffer to // expire. select { case e := <-ndpDisp.prefixC: @@ -2395,7 +2395,7 @@ func TestAutoGenTempAddrRegen(t *testing.T) { for _, addr := range tempAddrs { // Wait for a deprecation then invalidation event, or just an invalidation // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation timers could fire in any + // cases because the deprecation and invalidation jobs could execute in any // order. select { case e := <-ndpDisp.autoGenAddrC: @@ -2432,9 +2432,9 @@ func TestAutoGenTempAddrRegen(t *testing.T) { } } -// TestAutoGenTempAddrRegenTimerUpdates tests that a temporary address's -// regeneration timer gets updated when refreshing the address's lifetimes. -func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { +// TestAutoGenTempAddrRegenJobUpdates tests that a temporary address's +// regeneration job gets updated when refreshing the address's lifetimes. +func TestAutoGenTempAddrRegenJobUpdates(t *testing.T) { const ( nicID = 1 regenAfter = 2 * time.Second @@ -2533,7 +2533,7 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { // // A new temporary address should immediately be generated since the // regeneration time has already passed since the last address was generated - // - this regeneration does not depend on a timer. + // - this regeneration does not depend on a job. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(tempAddr2, newAddr) @@ -2559,11 +2559,11 @@ func TestAutoGenTempAddrRegenTimerUpdates(t *testing.T) { } // Set the maximum lifetimes for temporary addresses such that on the next - // RA, the regeneration timer gets reset. + // RA, the regeneration job gets scheduled again. // // The maximum lifetime is the sum of the minimum lifetimes for temporary // addresses + the time that has already passed since the last address was - // generated so that the regeneration timer is needed to generate the next + // generated so that the regeneration job is needed to generate the next // address. newLifetimes := newMinVLDuration + regenAfter + defaultAsyncNegativeEventTimeout ndpConfigs.MaxTempAddrValidLifetime = newLifetimes @@ -2993,9 +2993,9 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) { expectPrimaryAddr(addr2) } -// TestAutoGenAddrTimerDeprecation tests that an address is properly deprecated +// TestAutoGenAddrJobDeprecation tests that an address is properly deprecated // when its preferred lifetime expires. -func TestAutoGenAddrTimerDeprecation(t *testing.T) { +func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second @@ -3513,8 +3513,8 @@ func TestAutoGenAddrRemoval(t *testing.T) { } expectAutoGenAddrEvent(addr, invalidatedAddr) - // Wait for the original valid lifetime to make sure the original timer - // got stopped/cleaned up. + // Wait for the original valid lifetime to make sure the original job got + // cancelled/cleaned up. select { case <-ndpDisp.autoGenAddrC: t.Fatal("unexpectedly received an auto gen addr event") diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go new file mode 100644 index 000000000..1d37716c2 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -0,0 +1,335 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const neighborCacheSize = 512 // max entries per interface + +// neighborCache maps IP addresses to link addresses. It uses the Least +// Recently Used (LRU) eviction strategy to implement a bounded cache for +// dynmically acquired entries. It contains the state machine and configuration +// for running Neighbor Unreachability Detection (NUD). +// +// There are two types of entries in the neighbor cache: +// 1. Dynamic entries are discovered automatically by neighbor discovery +// protocols (e.g. ARP, NDP). These protocols will attempt to reconfirm +// reachability with the device once the entry's state becomes Stale. +// 2. Static entries are explicitly added by a user and have no expiration. +// Their state is always Static. The amount of static entries stored in the +// cache is unbounded. +// +// neighborCache implements NUDHandler. +type neighborCache struct { + nic *NIC + state *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + cache map[tcpip.Address]*neighborEntry + dynamic struct { + lru neighborEntryList + + // count tracks the amount of dynamic entries in the cache. This is + // needed since static entries do not count towards the LRU cache + // eviction strategy. + count uint16 + } +} + +var _ NUDHandler = (*neighborCache)(nil) + +// getOrCreateEntry retrieves a cache entry associated with addr. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[remoteAddr]; ok { + entry.mu.RLock() + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.lru.PushFront(entry) + } + entry.mu.RUnlock() + return entry + } + + // The entry that needs to be created must be dynamic since all static + // entries are directly added to the cache via addStaticEntry. + entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + if n.dynamic.count == neighborCacheSize { + e := n.dynamic.lru.Back() + e.mu.Lock() + + delete(n.cache, e.neigh.Addr) + n.dynamic.lru.Remove(e) + n.dynamic.count-- + + e.dispatchRemoveEventLocked() + e.setStateLocked(Unknown) + e.notifyWakersLocked() + e.mu.Unlock() + } + n.cache[remoteAddr] = entry + n.dynamic.lru.PushFront(entry) + n.dynamic.count++ + return entry +} + +// entry looks up the neighbor cache for translating address to link address +// (e.g. IP -> MAC). If the LinkEndpoint requests address resolution and there +// is a LinkAddressResolver registered with the network protocol, the cache +// attempts to resolve the address and returns ErrWouldBlock. If a Waker is +// provided, it will be notified when address resolution is complete (success +// or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification +// channel is returned for the top level caller to block. Channel is closed +// once address resolution is complete (success or not). +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { + if linkRes != nil { + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + } + return e, nil, nil + } + } + + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry.mu.Lock() + defer entry.mu.Unlock() + + switch s := entry.neigh.State; s { + case Reachable, Static: + return entry.neigh, nil, nil + + case Unknown, Incomplete, Stale, Delay, Probe: + entry.addWakerLocked(w) + + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.neigh, nil, tcpip.ErrNoLinkAddress + } + entry.done = make(chan struct{}) + } + + entry.handlePacketQueuedLocked() + return entry.neigh, entry.done, tcpip.ErrWouldBlock + + case Failed: + return entry.neigh, nil, tcpip.ErrNoLinkAddress + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", s)) + } +} + +// removeWaker removes a waker that has been added when link resolution for +// addr was requested. +func (n *neighborCache) removeWaker(addr tcpip.Address, waker *sleep.Waker) { + n.mu.Lock() + if entry, ok := n.cache[addr]; ok { + delete(entry.wakers, waker) + } + n.mu.Unlock() +} + +// entries returns all entries in the neighbor cache. +func (n *neighborCache) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(n.cache)) + n.mu.RLock() + for _, entry := range n.cache { + entry.mu.RLock() + entries = append(entries, entry.neigh) + entry.mu.RUnlock() + } + n.mu.RUnlock() + return entries +} + +// addStaticEntry adds a static entry to the neighbor cache, mapping an IP +// address to a link address. If a dynamic entry exists in the neighbor cache +// with the same address, it will be replaced with this static entry. If a +// static entry exists with the same address but different link address, it +// will be updated with the new link address. If a static entry exists with the +// same address and link address, nothing will happen. +func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAddress) { + n.mu.Lock() + defer n.mu.Unlock() + + if entry, ok := n.cache[addr]; ok { + entry.mu.Lock() + if entry.neigh.State != Static { + // Dynamic entry found with the same address. + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } else if entry.neigh.LinkAddr == linkAddr { + // Static entry found with the same address and link address. + entry.mu.Unlock() + return + } else { + // Static entry found with the same address but different link address. + entry.neigh.LinkAddr = linkAddr + entry.dispatchChangeEventLocked(entry.neigh.State) + entry.mu.Unlock() + return + } + + // Notify that resolution has been interrupted, just in case the entry was + // in the Incomplete or Probe state. + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) + n.cache[addr] = entry +} + +// removeEntryLocked removes the specified entry from the neighbor cache. +func (n *neighborCache) removeEntryLocked(entry *neighborEntry) { + if entry.neigh.State != Static { + n.dynamic.lru.Remove(entry) + n.dynamic.count-- + } + if entry.neigh.State != Failed { + entry.dispatchRemoveEventLocked() + } + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + + delete(n.cache, entry.neigh.Addr) +} + +// removeEntry removes a dynamic or static entry by address from the neighbor +// cache. Returns true if the entry was found and deleted. +func (n *neighborCache) removeEntry(addr tcpip.Address) bool { + n.mu.Lock() + defer n.mu.Unlock() + + entry, ok := n.cache[addr] + if !ok { + return false + } + + entry.mu.Lock() + defer entry.mu.Unlock() + + n.removeEntryLocked(entry) + return true +} + +// clear removes all dynamic and static entries from the neighbor cache. +func (n *neighborCache) clear() { + n.mu.Lock() + defer n.mu.Unlock() + + for _, entry := range n.cache { + entry.mu.Lock() + entry.dispatchRemoveEventLocked() + entry.setStateLocked(Unknown) + entry.notifyWakersLocked() + entry.mu.Unlock() + } + + n.dynamic.lru = neighborEntryList{} + n.cache = make(map[tcpip.Address]*neighborEntry) + n.dynamic.count = 0 +} + +// config returns the NUD configuration. +func (n *neighborCache) config() NUDConfigurations { + return n.state.Config() +} + +// setConfig changes the NUD configuration. +// +// If config contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *neighborCache) setConfig(config NUDConfigurations) { + config.resetInvalidFields() + n.state.SetConfig(config) +} + +// HandleProbe implements NUDHandler.HandleProbe by following the logic defined +// in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled +// by the caller. +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) + entry.mu.Lock() + entry.handleProbeLocked(remoteLinkAddr) + entry.mu.Unlock() +} + +// HandleConfirmation implements NUDHandler.HandleConfirmation by following the +// logic defined in RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce cryptographically +// generated addresses, as defined in RFC 3972, Cryptographically Generated +// Addresses (CGA). This ensures that the claimed source of an NDP message is +// the owner of the claimed address. +func (n *neighborCache) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleConfirmationLocked(linkAddr, flags) + entry.mu.Unlock() + } + // The confirmation SHOULD be silently discarded if the recipient did not + // initiate any communication with the target. This is indicated if there is + // no matching entry for the remote address. +} + +// HandleUpperLevelConfirmation implements +// NUDHandler.HandleUpperLevelConfirmation by following the logic defined in +// RFC 4861 section 7.3.1. +func (n *neighborCache) HandleUpperLevelConfirmation(addr tcpip.Address) { + n.mu.RLock() + entry, ok := n.cache[addr] + n.mu.RUnlock() + if ok { + entry.mu.Lock() + entry.handleUpperLevelConfirmationLocked() + entry.mu.Unlock() + } +} diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go new file mode 100644 index 000000000..4cb2c9c6b --- /dev/null +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -0,0 +1,1752 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "bytes" + "encoding/binary" + "fmt" + "math" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // entryStoreSize is the default number of entries that will be generated and + // added to the entry store. This number needs to be larger than the size of + // the neighbor cache to give ample opportunity for verifying behavior during + // cache overflows. Four times the size of the neighbor cache allows for + // three complete cache overflows. + entryStoreSize = 4 * neighborCacheSize + + // typicalLatency is the typical latency for an ARP or NDP packet to travel + // to a router and back. + typicalLatency = time.Millisecond + + // testEntryBroadcastAddr is a special address that indicates a packet should + // be sent to all nodes. + testEntryBroadcastAddr = tcpip.Address("broadcast") + + // testEntryLocalAddr is the source address of neighbor probes. + testEntryLocalAddr = tcpip.Address("local_addr") + + // testEntryBroadcastLinkAddr is a special link address sent back to + // multicast neighbor probes. + testEntryBroadcastLinkAddr = tcpip.LinkAddress("mac_broadcast") + + // infiniteDuration indicates that a task will not occur in our lifetime. + infiniteDuration = time.Duration(math.MaxInt64) +) + +// entryDiffOpts returns the options passed to cmp.Diff to compare neighbor +// entries. The UpdatedAt field is ignored due to a lack of a deterministic +// method to predict the time that an event will be dispatched. +func entryDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + } +} + +// entryDiffOptsWithSort is like entryDiffOpts but also includes an option to +// sort slices of entries for cases where ordering must be ignored. +func entryDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { + config.resetInvalidFields() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return &neighborCache{ + nic: &NIC{ + stack: &Stack{ + clock: clock, + nudDisp: nudDisp, + }, + id: 1, + }, + state: NewNUDState(config, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } +} + +// testEntryStore contains a set of IP to NeighborEntry mappings. +type testEntryStore struct { + mu sync.RWMutex + entriesMap map[tcpip.Address]NeighborEntry +} + +func toAddress(i int) tcpip.Address { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint16(i)) + return tcpip.Address(buf.String()) +} + +func toLinkAddress(i int) tcpip.LinkAddress { + buf := new(bytes.Buffer) + binary.Write(buf, binary.BigEndian, uint8(1)) + binary.Write(buf, binary.BigEndian, uint8(0)) + binary.Write(buf, binary.BigEndian, uint32(i)) + return tcpip.LinkAddress(buf.String()) +} + +// newTestEntryStore returns a testEntryStore pre-populated with entries. +func newTestEntryStore() *testEntryStore { + store := &testEntryStore{ + entriesMap: make(map[tcpip.Address]NeighborEntry), + } + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + linkAddr := toLinkAddress(i) + + store.entriesMap[addr] = NeighborEntry{ + Addr: addr, + LocalAddr: testEntryLocalAddr, + LinkAddr: linkAddr, + } + } + return store +} + +// size returns the number of entries in the store. +func (s *testEntryStore) size() int { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.entriesMap) +} + +// entry returns the entry at index i. Returns an empty entry and false if i is +// out of bounds. +func (s *testEntryStore) entry(i int) (NeighborEntry, bool) { + return s.entryByAddr(toAddress(i)) +} + +// entryByAddr returns the entry matching addr for situations when the index is +// not available. Returns an empty entry and false if no entries match addr. +func (s *testEntryStore) entryByAddr(addr tcpip.Address) (NeighborEntry, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.entriesMap[addr] + return entry, ok +} + +// entries returns all entries in the store. +func (s *testEntryStore) entries() []NeighborEntry { + entries := make([]NeighborEntry, 0, len(s.entriesMap)) + s.mu.RLock() + defer s.mu.RUnlock() + for i := 0; i < entryStoreSize; i++ { + addr := toAddress(i) + if entry, ok := s.entriesMap[addr]; ok { + entries = append(entries, entry) + } + } + return entries +} + +// set modifies the link addresses of an entry. +func (s *testEntryStore) set(i int, linkAddr tcpip.LinkAddress) { + addr := toAddress(i) + s.mu.Lock() + defer s.mu.Unlock() + if entry, ok := s.entriesMap[addr]; ok { + entry.LinkAddr = linkAddr + s.entriesMap[addr] = entry + } +} + +// testNeighborResolver implements LinkAddressResolver to emulate sending a +// neighbor probe. +type testNeighborResolver struct { + clock tcpip.Clock + neigh *neighborCache + entries *testEntryStore + delay time.Duration + onLinkAddressRequest func() +} + +var _ LinkAddressResolver = (*testNeighborResolver)(nil) + +func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + // Delay handling the request to emulate network latency. + r.clock.AfterFunc(r.delay, func() { + r.fakeRequest(addr) + }) + + // Execute post address resolution action, if available. + if f := r.onLinkAddressRequest; f != nil { + f() + } + return nil +} + +// fakeRequest emulates handling a response for a link address request. +func (r *testNeighborResolver) fakeRequest(addr tcpip.Address) { + if entry, ok := r.entries.entryByAddr(addr); ok { + r.neigh.HandleConfirmation(addr, entry.LinkAddr, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } +} + +func (*testNeighborResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == testEntryBroadcastAddr { + return testEntryBroadcastLinkAddr, true + } + return "", false +} + +func (*testNeighborResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 0 +} + +type entryEvent struct { + nicID tcpip.NICID + address tcpip.Address + linkAddr tcpip.LinkAddress + state NeighborState +} + +func TestNeighborCacheGetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheSetConfig(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + neigh.setConfig(c) + + if got, want := neigh.config(), c; got != want { + t.Errorf("got neigh.config() = %+v, want = %+v", got, want) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheEntry(t *testing.T) { + c := DefaultNUDConfigurations() + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a +// LinkAddressResolver returns ErrNoLinkAddress. +func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { + nudDisp := testNUDDispatcher{} + c := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, c, clock) + store := newTestEntryStore() + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) + if err != tcpip.ErrNoLinkAddress { + t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveEntry(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } +} + +type testContext struct { + clock *fakeClock + neigh *neighborCache + store *testEntryStore + linkRes *testNeighborResolver + nudDisp *testNUDDispatcher +} + +func newTestContext(c NUDConfigurations) testContext { + nudDisp := &testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(nudDisp, c, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + return testContext{ + clock: clock, + neigh: neigh, + store: store, + linkRes: linkRes, + nudDisp: nudDisp, + } +} + +type overflowOptions struct { + startAtEntryIndex int + wantStaticEntries []NeighborEntry +} + +func (c *testContext) overflowCache(opts overflowOptions) error { + // Fill the neighbor cache to capacity to verify the LRU eviction strategy is + // working properly after the entry removal. + for i := opts.startAtEntryIndex; i < c.store.size(); i++ { + // Add a new entry + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + + var wantEvents []testEntryEventInfo + + // When beyond the full capacity, the cache will evict an entry as per the + // LRU eviction strategy. Note that the number of static entries should not + // affect the total number of dynamic entries that can be added. + if i >= neighborCacheSize+opts.startAtEntryIndex { + removedEntry, ok := c.store.entry(i - neighborCacheSize) + if !ok { + return fmt.Errorf("store.entry(%d) not found", i-neighborCacheSize) + } + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }) + } + + wantEvents = append(wantEvents, testEntryEventInfo{ + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, testEntryEventInfo{ + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }) + + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the most recent entries. The order of entries reported + // by entries() is undeterministic, so entries have to be sorted before + // comparison. + wantUnsortedEntries := opts.wantStaticEntries + for i := c.store.size() - neighborCacheSize; i < c.store.size(); i++ { + entry, ok := c.store.entry(i) + if !ok { + return fmt.Errorf("c.store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(c.neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + return fmt.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + return fmt.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + return nil +} + +// TestNeighborCacheOverflow verifies that the LRU cache eviction strategy +// respects the dynamic entry count. +func TestNeighborCacheOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheRemoveEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when an entry is removed. +func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(c.neigh.config().RetransmitTimer) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the entry + c.neigh.removeEntry(entry.Addr) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress verifies that +// adding a duplicate static entry with the same link address does not dispatch +// any events. +func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + + // No more events should have been dispatched. + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress verifies that +// adding a duplicate static entry with a different link address dispatches a +// change event. +func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { + config := DefaultNUDConfigurations() + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a duplicate entry with a different link address + staticLinkAddr += "duplicate" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + defer c.nudDisp.mu.Unlock() + if diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +// TestNeighborCacheRemoveStaticEntryThenOverflow verifies that the LRU cache +// eviction strategy respects the dynamic entry count when a static entry is +// added then removed. In this case, the dynamic entry count shouldn't have +// been touched. +func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a static entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Remove the static entry that was just added + c.neigh.removeEntry(entry.Addr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +// TestNeighborCacheOverwriteWithStaticEntryThenOverflow verifies that the LRU +// cache eviction strategy keeps count of the dynamic entry count when an entry +// is overwritten by a static entry. Static entries should not count towards +// the size of the LRU cache. +func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Override the entry with a static one using the same address + staticLinkAddr := entry.LinkAddr + "static" + c.neigh.addStaticEntry(entry.Addr, staticLinkAddr) + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: staticLinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheNotifiesWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + id, ok := s.Fetch(false /* block */) + if !ok { + t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + if id != wakerID { + t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheRemoveWaker(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + w := sleep.Waker{} + s := sleep.Sleeper{} + const wakerID = 1 + s.AddWaker(&w, wakerID) + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh == nil { + t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + } + + // Remove the waker before the neighbor cache has the opportunity to send a + // notification. + neigh.removeWaker(entry.Addr, &w) + clock.advance(typicalLatency) + + select { + case <-doneCh: + default: + t.Fatal("expected notification from done channel") + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Errorf("unexpected notification from waker with id %d", id) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) + e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + if err != nil { + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + opts := overflowOptions{ + startAtEntryIndex: 1, + wantStaticEntries: []NeighborEntry{ + { + Addr: entry.Addr, + LocalAddr: "", // static entries don't need a local address + LinkAddr: entry.LinkAddr, + State: Static, + }, + }, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheClear(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add a dynamic entry. + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Add a static entry. + neigh.addStaticEntry(entryTestAddr1, entryTestLinkAddr1) + + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Clear shoud remove both dynamic and static entries. + neigh.clear() + + // Remove events dispatched from clear() have no deterministic order so they + // need to be sorted beforehand. + wantUnsortedEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, + } + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, wantUnsortedEvents, eventDiffOptsWithSort()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +// TestNeighborCacheClearThenOverflow verifies that the LRU cache eviction +// strategy keeps count of the dynamic entry count when all entries are +// cleared. +func TestNeighborCacheClearThenOverflow(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + c := newTestContext(config) + + // Add a dynamic entry + entry, ok := c.store.entry(0) + if !ok { + t.Fatalf("c.store.entry(0) not found") + } + _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + c.clock.advance(typicalLatency) + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + + // Clear the cache. + c.neigh.clear() + { + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + c.nudDisp.mu.Lock() + diff := cmp.Diff(c.nudDisp.events, wantEvents, eventDiffOpts()...) + c.nudDisp.events = nil + c.nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + opts := overflowOptions{ + startAtEntryIndex: 0, + } + if err := c.overflowCache(opts); err != nil { + t.Errorf("c.overflowCache(%+v): %s", opts, err) + } +} + +func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { + config := DefaultNUDConfigurations() + // Stay in Reachable so the cache can overflow + config.BaseReachableTime = infiniteDuration + config.MinRandomFactor = 1 + config.MaxRandomFactor = 1 + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + frequentlyUsedEntry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + + // The following logic is very similar to overflowCache, but + // periodically refreshes the frequently used entry. + + // Fill the neighbor cache to capacity + for i := 0; i < neighborCacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Keep adding more entries + for i := neighborCacheSize; i < store.size(); i++ { + // Periodically refresh the frequently used entry + if i%(neighborCacheSize/2) == 0 { + _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + } + } + + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // An entry should have been removed, as per the LRU eviction strategy + removedEntry, ok := store.entry(i - neighborCacheSize + 1) + if !ok { + t.Fatalf("store.entry(%d) not found", i-neighborCacheSize+1) + } + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestRemoved, + NICID: 1, + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, + { + EventType: entryTestAdded, + NICID: 1, + Addr: entry.Addr, + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: 1, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, + } + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.events = nil + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } + + // Expect to find only the frequently used entry and the most recent entries. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + wantUnsortedEntries := []NeighborEntry{ + { + Addr: frequentlyUsedEntry.Addr, + LocalAddr: frequentlyUsedEntry.LocalAddr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, + }, + } + + for i := store.size() - neighborCacheSize + 1; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Fatalf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } + + // No more events should have been dispatched. + nudDisp.mu.Lock() + defer nudDisp.mu.Unlock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheConcurrent(t *testing.T) { + const concurrentProcesses = 16 + + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + storeEntries := store.entries() + for _, entry := range storeEntries { + var wg sync.WaitGroup + for r := 0; r < concurrentProcesses; r++ { + wg.Add(1) + go func(entry NeighborEntry) { + defer wg.Done() + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + } + }(entry) + } + + // Wait for all gorountines to send a request + wg.Wait() + + // Process all the requests for a single entry concurrently + clock.advance(typicalLatency) + } + + // All goroutines add in the same order and add more values than can fit in + // the cache. Our eviction strategy requires that the last entries are + // present, up to the size of the neighbor cache, and the rest are missing. + // The order of entries reported by entries() is undeterministic, so entries + // have to be sorted before comparison. + var wantUnsortedEntries []NeighborEntry + for i := store.size() - neighborCacheSize; i < store.size(); i++ { + entry, ok := store.entry(i) + if !ok { + t.Errorf("store.entry(%d) not found", i) + } + wantEntry := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) + } + + if diff := cmp.Diff(neigh.entries(), wantUnsortedEntries, entryDiffOptsWithSort()...); diff != "" { + t.Errorf("neighbor entries mismatch (-got, +want):\n%s", diff) + } +} + +func TestNeighborCacheReplace(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + // Add an entry + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + + // Verify the entry exists + e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + if doneCh != nil { + t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + } + if t.Failed() { + t.FailNow() + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + } + + // Notify of a link address change + var updatedLinkAddr tcpip.LinkAddress + { + entry, ok := store.entry(1) + if !ok { + t.Fatalf("store.entry(1) not found") + } + updatedLinkAddr = entry.LinkAddr + } + store.set(0, updatedLinkAddr) + neigh.HandleConfirmation(entry.Addr, updatedLinkAddr, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + + // Requesting the entry again should start address resolution + { + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(config.DelayFirstProbeTime + typicalLatency) + select { + case <-doneCh: + default: + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + } + } + + // Verify the entry's new link address + { + e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + clock.advance(typicalLatency) + if err != nil { + t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want = NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: updatedLinkAddr, + State: Reachable, + } + if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + } +} + +func TestNeighborCacheResolutionFailed(t *testing.T) { + config := DefaultNUDConfigurations() + + nudDisp := testNUDDispatcher{} + clock := newFakeClock() + neigh := newTestNeighborCache(&nudDisp, config, clock) + store := newTestEntryStore() + + var requestCount uint32 + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + onLinkAddressRequest: func() { + atomic.AddUint32(&requestCount, 1) + }, + } + + // First, sanity check that resolution is working + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + clock.advance(typicalLatency) + got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + } + want := NeighborEntry{ + Addr: entry.Addr, + LocalAddr: entry.LocalAddr, + LinkAddr: entry.LinkAddr, + State: Reachable, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + } + + // Verify that address resolution for an unknown address returns ErrNoLinkAddress + before := atomic.LoadUint32(&requestCount) + + entry.Addr += "2" + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } + + maxAttempts := neigh.config().MaxUnicastProbes + if got, want := atomic.LoadUint32(&requestCount)-before, maxAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } +} + +// TestNeighborCacheResolutionTimeout simulates sending MaxMulticastProbes +// probes and not retrieving a confirmation before the duration defined by +// MaxMulticastProbes * RetransmitTimer. +func TestNeighborCacheResolutionTimeout(t *testing.T) { + config := DefaultNUDConfigurations() + config.RetransmitTimer = time.Millisecond // small enough to cause timeout + + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: time.Minute, // large enough to cause timeout + } + + entry, ok := store.entry(0) + if !ok { + t.Fatalf("store.entry(0) not found") + } + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) + clock.advance(waitFor) + if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + } +} + +// TestNeighborCacheStaticResolution checks that static link addresses are +// resolved immediately and don't send resolution requests. +func TestNeighborCacheStaticResolution(t *testing.T) { + config := DefaultNUDConfigurations() + clock := newFakeClock() + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: typicalLatency, + } + + got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + if err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + } + want := NeighborEntry{ + Addr: testEntryBroadcastAddr, + LocalAddr: testEntryLocalAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, + } + if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { + t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + } +} + +func BenchmarkCacheClear(b *testing.B) { + b.StopTimer() + config := DefaultNUDConfigurations() + clock := &tcpip.StdClock{} + neigh := newTestNeighborCache(nil, config, clock) + store := newTestEntryStore() + linkRes := &testNeighborResolver{ + clock: clock, + neigh: neigh, + entries: store, + delay: 0, + } + + // Clear for every possible size of the cache + for cacheSize := 0; cacheSize < neighborCacheSize; cacheSize++ { + // Fill the neighbor cache to capacity. + for i := 0; i < cacheSize; i++ { + entry, ok := store.entry(i) + if !ok { + b.Fatalf("store.entry(%d) not found", i) + } + _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + if err != tcpip.ErrWouldBlock { + b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + } + if doneCh != nil { + <-doneCh + } + } + + b.StartTimer() + neigh.clear() + b.StopTimer() + } +} diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go new file mode 100644 index 000000000..0068cacb8 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -0,0 +1,482 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +// NeighborEntry describes a neighboring device in the local network. +type NeighborEntry struct { + Addr tcpip.Address + LocalAddr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +// NeighborState defines the state of a NeighborEntry within the Neighbor +// Unreachability Detection state machine, as per RFC 4861 section 7.3.2. +type NeighborState uint8 + +const ( + // Unknown means reachability has not been verified yet. This is the initial + // state of entries that have been created automatically by the Neighbor + // Unreachability Detection state machine. + Unknown NeighborState = iota + // Incomplete means that there is an outstanding request to resolve the + // address. + Incomplete + // Reachable means the path to the neighbor is functioning properly for both + // receive and transmit paths. + Reachable + // Stale means reachability to the neighbor is unknown, but packets are still + // able to be transmitted to the possibly stale link address. + Stale + // Delay means reachability to the neighbor is unknown and pending + // confirmation from an upper-level protocol like TCP, but packets are still + // able to be transmitted to the possibly stale link address. + Delay + // Probe means a reachability confirmation is actively being sought by + // periodically retransmitting reachability probes until a reachability + // confirmation is received, or until the max amount of probes has been sent. + Probe + // Static describes entries that have been explicitly added by the user. They + // do not expire and are not deleted until explicitly removed. + Static + // Failed means traffic should not be sent to this neighbor since attempts of + // reachability have returned inconclusive. + Failed +) + +// neighborEntry implements a neighbor entry's individual node behavior, as per +// RFC 4861 section 7.3.3. Neighbor Unreachability Detection operates in +// parallel with the sending of packets to a neighbor, necessitating the +// entry's lock to be acquired for all operations. +type neighborEntry struct { + neighborEntryEntry + + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkRes provides the functionality to send reachability probes, used in + // Neighbor Unreachability Detection. + linkRes LinkAddressResolver + + // nudState points to the Neighbor Unreachability Detection configuration. + nudState *NUDState + + // mu protects the fields below. + mu sync.RWMutex + + neigh NeighborEntry + + // wakers is a set of waiters for address resolution result. Anytime state + // transitions out of incomplete these waiters are notified. It is nil iff + // address resolution is ongoing and no clients are waiting for the result. + wakers map[*sleep.Waker]struct{} + + // done is used to allow callers to wait on address resolution. It is nil + // iff nudState is not Reachable and address resolution is not yet in + // progress. + done chan struct{} + + isRouter bool + job *tcpip.Job +} + +// newNeighborEntry creates a neighbor cache entry starting at the default +// state, Unknown. Transition out of Unknown by calling either +// `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created +// neighborEntry. +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { + return &neighborEntry{ + nic: nic, + linkRes: linkRes, + nudState: nudState, + neigh: NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + State: Unknown, + }, + } +} + +// newStaticNeighborEntry creates a neighbor cache entry starting at the Static +// state. The entry can only transition out of Static by directly calling +// `setStateLocked`. +func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + if nic.stack.nudDisp != nil { + nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + } + return &neighborEntry{ + nic: nic, + nudState: state, + neigh: NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), + }, + } +} + +// addWaker adds w to the list of wakers waiting for address resolution. +// Assumes the entry has already been appropriately locked. +func (e *neighborEntry) addWakerLocked(w *sleep.Waker) { + if w == nil { + return + } + if e.wakers == nil { + e.wakers = make(map[*sleep.Waker]struct{}) + } + e.wakers[w] = struct{}{} +} + +// notifyWakersLocked notifies those waiting for address resolution, whether it +// succeeded or failed. Assumes the entry has already been appropriately locked. +func (e *neighborEntry) notifyWakersLocked() { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if ch := e.done; ch != nil { + close(ch) + e.done = nil + } +} + +// dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has +// been added. +func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry +// has changed state or link-layer address. +func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + } +} + +// dispatchRemoveEventLocked signals to stack's NUD Dispatcher that the entry +// has been removed. +func (e *neighborEntry) dispatchRemoveEventLocked() { + if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + } +} + +// setStateLocked transitions the entry to the specified state immediately. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +// +// e.mu MUST be locked. +func (e *neighborEntry) setStateLocked(next NeighborState) { + // Cancel the previously scheduled action, if there is one. Entries in + // Unknown, Stale, or Static state do not have scheduled actions. + if timer := e.job; timer != nil { + timer.Cancel() + } + + prev := e.neigh.State + e.neigh.State = next + e.neigh.UpdatedAt = time.Now() + config := e.nudState.Config() + + switch next { + case Incomplete: + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.linkEP); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendMulticastProbe() + + case Reachable: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + }) + e.job.Schedule(e.nudState.ReachableTime()) + + case Delay: + e.job = e.nic.stack.newJob(&e.mu, func() { + e.dispatchChangeEventLocked(Probe) + e.setStateLocked(Probe) + }) + e.job.Schedule(config.DelayFirstProbeTime) + + case Probe: + var retryCounter uint32 + var sendUnicastProbe func() + + sendUnicastProbe = func() { + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.linkEP); err != nil { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + if retryCounter == config.MaxUnicastProbes { + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + sendUnicastProbe() + + case Failed: + e.notifyWakersLocked() + e.job = e.nic.stack.newJob(&e.mu, func() { + e.nic.neigh.removeEntryLocked(e) + }) + e.job.Schedule(config.UnreachableTime) + + case Unknown, Stale, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid state transition from %q to %q", prev, next)) + } +} + +// handlePacketQueuedLocked advances the state machine according to a packet +// being queued for outgoing transmission. +// +// Follows the logic defined in RFC 4861 section 7.3.3. +func (e *neighborEntry) handlePacketQueuedLocked() { + switch e.neigh.State { + case Unknown: + e.dispatchAddEventLocked(Incomplete) + e.setStateLocked(Incomplete) + + case Stale: + e.dispatchChangeEventLocked(Delay) + e.setStateLocked(Delay) + + case Incomplete, Reachable, Delay, Probe, Static, Failed: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleProbeLocked processes an incoming neighbor probe (e.g. ARP request or +// Neighbor Solicitation for ARP or NDP, respectively). +// +// Follows the logic defined in RFC 4861 section 7.2.3. +func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { + // Probes MUST be silently discarded if the target address is tentative, does + // not exist, or not bound to the NIC as per RFC 4861 section 7.2.3. These + // checks MUST be done by the NetworkEndpoint. + + switch e.neigh.State { + case Unknown, Incomplete, Failed: + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchAddEventLocked(Stale) + e.setStateLocked(Stale) + e.notifyWakersLocked() + + case Reachable, Delay, Probe: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + + case Stale: + if e.neigh.LinkAddr != remoteLinkAddr { + e.neigh.LinkAddr = remoteLinkAddr + e.dispatchChangeEventLocked(Stale) + } + + case Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleConfirmationLocked processes an incoming neighbor confirmation +// (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, respectively). +// +// Follows the state machine defined by RFC 4861 section 7.2.5. +// +// TODO(gvisor.dev/issue/2277): To protect against ARP poisoning and other +// attacks against NDP functions, Secure Neighbor Discovery (SEND) Protocol +// should be deployed where preventing access to the broadcast segment might +// not be possible. SEND uses RSA key pairs to produce Cryptographically +// Generated Addresses (CGA), as defined in RFC 3972. This ensures that the +// claimed source of an NDP message is the owner of the claimed address. +func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) { + switch e.neigh.State { + case Incomplete: + if len(linkAddr) == 0 { + // "If the link layer has addresses and no Target Link-Layer Address + // option is included, the receiving node SHOULD silently discard the + // received advertisement." - RFC 4861 section 7.2.5 + break + } + + e.neigh.LinkAddr = linkAddr + if flags.Solicited { + e.dispatchChangeEventLocked(Reachable) + e.setStateLocked(Reachable) + } else { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + e.isRouter = flags.IsRouter + e.notifyWakersLocked() + + // "Note that the Override flag is ignored if the entry is in the + // INCOMPLETE state." - RFC 4861 section 7.2.5 + + case Reachable, Stale, Delay, Probe: + sameLinkAddr := e.neigh.LinkAddr == linkAddr + + if !sameLinkAddr { + if !flags.Override { + if e.neigh.State == Reachable { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } + break + } + + e.neigh.LinkAddr = linkAddr + + if !flags.Solicited { + if e.neigh.State != Stale { + e.dispatchChangeEventLocked(Stale) + e.setStateLocked(Stale) + } else { + // Notify the LinkAddr change, even though NUD state hasn't changed. + e.dispatchChangeEventLocked(e.neigh.State) + } + break + } + } + + if flags.Solicited && (flags.Override || sameLinkAddr) { + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + } + // Set state to Reachable again to refresh timers. + e.setStateLocked(Reachable) + e.notifyWakersLocked() + } + + if e.isRouter && !flags.IsRouter { + // "In those cases where the IsRouter flag changes from TRUE to FALSE as + // a result of this update, the node MUST remove that router from the + // Default Router List and update the Destination Cache entries for all + // destinations using that neighbor as a router as specified in Section + // 7.3.3. This is needed to detect when a node that is used as a router + // stops forwarding packets due to being configured as a host." + // - RFC 4861 section 7.2.5 + e.nic.mu.Lock() + e.nic.mu.ndp.invalidateDefaultRouter(e.neigh.Addr) + e.nic.mu.Unlock() + } + e.isRouter = flags.IsRouter + + case Unknown, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} + +// handleUpperLevelConfirmationLocked processes an incoming upper-level protocol +// (e.g. TCP acknowledgements) reachability confirmation. +func (e *neighborEntry) handleUpperLevelConfirmationLocked() { + switch e.neigh.State { + case Reachable, Stale, Delay, Probe: + if e.neigh.State != Reachable { + e.dispatchChangeEventLocked(Reachable) + // Set state to Reachable again to refresh timers. + } + e.setStateLocked(Reachable) + + case Unknown, Incomplete, Failed, Static: + // Do nothing + + default: + panic(fmt.Sprintf("Invalid cache entry state: %s", e.neigh.State)) + } +} diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go new file mode 100644 index 000000000..08c9ccd25 --- /dev/null +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -0,0 +1,2770 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "math" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + entryTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + + entryTestNICID tcpip.NICID = 1 + entryTestAddr1 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01") + entryTestAddr2 = tcpip.Address("\x00\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + entryTestLinkAddr1 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x01") + entryTestLinkAddr2 = tcpip.LinkAddress("\x0a\x00\x00\x00\x00\x02") + + // entryTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, + // except where another value is explicitly used. It is chosen to match the + // MTU of loopback interfaces on Linux systems. + entryTestNetDefaultMTU = 65536 +) + +// eventDiffOpts are the options passed to cmp.Diff to compare entry events. +// The UpdatedAt field is ignored due to a lack of a deterministic method to +// predict the time that an event will be dispatched. +func eventDiffOpts() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + } +} + +// eventDiffOptsWithSort is like eventDiffOpts but also includes an option to +// sort slices of events for cases where ordering must be ignored. +func eventDiffOptsWithSort() []cmp.Option { + return []cmp.Option{ + cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + }), + } +} + +// The following unit tests exercise every state transition and verify its +// behavior with RFC 4681. +// +// | From | To | Cause | Action | Event | +// | ========== | ========== | ========================================== | =============== | ======= | +// | Unknown | Unknown | Confirmation w/ unknown address | | Added | +// | Unknown | Incomplete | Packet queued to unknown address | Send probe | Added | +// | Unknown | Stale | Probe w/ unknown address | | Added | +// | Incomplete | Incomplete | Retransmit timer expired | Send probe | Changed | +// | Incomplete | Reachable | Solicited confirmation | Notify wakers | Changed | +// | Incomplete | Stale | Unsolicited confirmation | Notify wakers | Changed | +// | Incomplete | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Reachable | Reachable | Confirmation w/ different isRouter flag | Update IsRouter | | +// | Reachable | Stale | Reachable timer expired | | Changed | +// | Reachable | Stale | Probe or confirmation w/ different address | | Changed | +// | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Override confirmation | Update LinkAddr | Changed | +// | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | +// | Stale | Delay | Packet sent | | Changed | +// | Delay | Reachable | Upper-layer confirmation | | Changed | +// | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Stale | Probe or confirmation w/ different address | | Changed | +// | Delay | Probe | Delay timer expired | Send probe | Changed | +// | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Stale | Probe or confirmation w/ different address | | Changed | +// | Probe | Probe | Retransmit timer expired | Send probe | Changed | +// | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | +// | Failed | | Unreachability timer expired | Delete entry | | + +type testEntryEventType uint8 + +const ( + entryTestAdded testEntryEventType = iota + entryTestChanged + entryTestRemoved +) + +func (t testEntryEventType) String() string { + switch t { + case entryTestAdded: + return "add" + case entryTestChanged: + return "change" + case entryTestRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +// Fields are exported for use with cmp.Diff. +type testEntryEventInfo struct { + EventType testEntryEventType + NICID tcpip.NICID + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAt time.Time +} + +func (e testEntryEventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) +} + +// testNUDDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type testNUDDispatcher struct { + mu sync.Mutex + events []testEntryEventInfo +} + +var _ NUDDispatcher = (*testNUDDispatcher)(nil) + +func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { + d.mu.Lock() + defer d.mu.Unlock() + d.events = append(d.events, e) +} + +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestAdded, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestChanged, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { + d.queueEvent(testEntryEventInfo{ + EventType: entryTestRemoved, + NICID: nicID, + Addr: addr, + LinkAddr: linkAddr, + State: state, + UpdatedAt: updatedAt, + }) +} + +type entryTestLinkResolver struct { + mu sync.Mutex + probes []entryTestProbeInfo +} + +var _ LinkAddressResolver = (*entryTestLinkResolver)(nil) + +type entryTestProbeInfo struct { + RemoteAddress tcpip.Address + RemoteLinkAddress tcpip.LinkAddress + LocalAddress tcpip.Address +} + +func (p entryTestProbeInfo) String() string { + return fmt.Sprintf("probe with RemoteAddress=%q, RemoteLinkAddress=%q, LocalAddress=%q", p.RemoteAddress, p.RemoteLinkAddress, p.LocalAddress) +} + +// LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts +// to the local network if linkAddr is the zero value. +func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { + p := entryTestProbeInfo{ + RemoteAddress: addr, + RemoteLinkAddress: linkAddr, + LocalAddress: localAddr, + } + r.mu.Lock() + defer r.mu.Unlock() + r.probes = append(r.probes, p) + return nil +} + +// ResolveStaticAddress attempts to resolve address without sending requests. +// It either resolves the name immediately or returns the empty LinkAddress. +func (r *entryTestLinkResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + return "", false +} + +// LinkAddressProtocol returns the network protocol of the addresses this +// resolver can resolve. +func (r *entryTestLinkResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return entryTestNetNumber +} + +func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *entryTestLinkResolver, *fakeClock) { + clock := newFakeClock() + disp := testNUDDispatcher{} + nic := NIC{ + id: entryTestNICID, + linkEP: nil, // entryTestLinkResolver doesn't use a LinkEndpoint + stack: &Stack{ + clock: clock, + nudDisp: &disp, + }, + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + nudState := NewNUDState(c, rng) + linkRes := entryTestLinkResolver{} + entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) + + // Stub out ndpState to verify modification of default routers. + nic.mu.ndp = ndpState{ + nic: &nic, + defaultRouters: make(map[tcpip.Address]defaultRouterState), + } + + // Stub out the neighbor cache to verify deletion from the cache. + nic.neigh = &neighborCache{ + nic: &nic, + state: nudState, + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + nic.neigh.cache[entryTestAddr1] = entry + + return entry, &disp, &linkRes, clock +} + +// TestEntryInitiallyUnknown verifies that the state of a newly created +// neighborEntry is Unknown. +func TestEntryInitiallyUnknown(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Unknown; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(time.Hour) + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + // No events should have been dispatched. + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryUnknownToIncomplete(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + { + nudDisp.mu.Lock() + diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...) + nudDisp.mu.Unlock() + if diff != "" { + t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + } +} + +func TestEntryUnknownToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + // No probes should have been sent. + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + updatedAt := e.neigh.UpdatedAt + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should remain the same during address resolution. + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.RetransmitTimer) + + // UpdatedAt should change after failing address resolution. Timing out after + // sending the last probe transitions the entry to Failed. + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + clock.advance(c.RetransmitTimer) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) + } + e.mu.Unlock() +} + +func TestEntryIncompleteToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +// TestEntryAddsAndClearsWakers verifies that wakers are added when +// addWakerLocked is called and cleared when address resolution finishes. In +// this case, address resolution will finish when transitioning from Incomplete +// to Reachable. +func TestEntryAddsAndClearsWakers(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got := e.wakers; got != nil { + t.Errorf("got e.wakers = %v, want = nil", got) + } + e.addWakerLocked(&w) + if got, want := w.IsAsserted(), false; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + if e.wakers == nil { + t.Error("expected e.wakers to be non-nil") + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.wakers != nil { + t.Errorf("got e.wakers = %v, want = nil", e.wakers) + } + if got, want := w.IsAsserted(), true; got != want { + t.Errorf("waker.IsAsserted() = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + if diff := cmp.Diff(linkRes.probes, wantProbes); diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + linkRes.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToStale(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryIncompleteToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Incomplete; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + waitFor := c.RetransmitTimer * time.Duration(c.MaxMulticastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The Incomplete-to-Incomplete state transition is tested here by + // verifying that 3 reachability probes were sent. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +type testLocker struct{} + +var _ sync.Locker = (*testLocker)(nil) + +func (*testLocker) Lock() {} +func (*testLocker) Unlock() {} + +func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.isRouter, true; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + e.nic.mu.ndp.defaultRouters[entryTestAddr1] = defaultRouterState{ + invalidationJob: e.nic.stack.newJob(&testLocker{}, func() {}), + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.isRouter, false; got != want { + t.Errorf("got e.isRouter = %t, want = %t", got, want) + } + if _, ok := e.nic.mu.ndp.defaultRouters[entryTestAddr1]; ok { + t.Errorf("unexpected defaultRouter for %s", entryTestAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryReachableToStaleWhenTimeout(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr1) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaleToDelay(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleUpperLevelConfirmationLocked() + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryDelayToProbe(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if got, want := e.neigh.State, Delay; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleProbeLocked(entryTestLinkAddr2) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Stale; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + +func TestEntryProbeToFailed(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + e.mu.Lock() + if got, want := e.neigh.State, Failed; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.mu.Unlock() +} + +func TestEntryFailedGetsDeleted(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 3 + c.MaxUnicastProbes = 3 + e, nudDisp, linkRes, clock := entryTestSetup(c) + + // Verify the cache contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; !ok { + t.Errorf("expected entry %q to exist in the neighbor cache", entryTestAddr1) + } + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime + clock.advance(waitFor) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The next three probe are caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestRemoved, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() + + // Verify the cache no longer contains the entry. + if _, ok := e.nic.neigh.cache[entryTestAddr1]; ok { + t.Errorf("entry %q should have been deleted from the neighbor cache", entryTestAddr1) + } +} diff --git a/pkg/tcpip/stack/neighborstate_string.go b/pkg/tcpip/stack/neighborstate_string.go new file mode 100644 index 000000000..aa7311ec6 --- /dev/null +++ b/pkg/tcpip/stack/neighborstate_string.go @@ -0,0 +1,44 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by "stringer -type NeighborState"; DO NOT EDIT. + +package stack + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[Unknown-0] + _ = x[Incomplete-1] + _ = x[Reachable-2] + _ = x[Stale-3] + _ = x[Delay-4] + _ = x[Probe-5] + _ = x[Static-6] + _ = x[Failed-7] +} + +const _NeighborState_name = "UnknownIncompleteReachableStaleDelayProbeStaticFailed" + +var _NeighborState_index = [...]uint8{0, 7, 17, 26, 31, 36, 41, 47, 53} + +func (i NeighborState) String() string { + if i >= NeighborState(len(_NeighborState_index)-1) { + return "NeighborState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NeighborState_name[_NeighborState_index[i]:_NeighborState_index[i+1]] +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index afb7dfeaf..f21066fce 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math/rand" "reflect" "sort" "strings" @@ -45,6 +46,7 @@ type NIC struct { context NICContext stats NICStats + neigh *neighborCache mu struct { sync.RWMutex @@ -141,6 +143,16 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{} } + // Check for Neighbor Unreachability Detection support. + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) + nic.neigh = &neighborCache{ + nic: nic, + state: NewNUDState(stack.nudConfigs, rng), + cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), + } + } + nic.linkEP.Attach(nic) return nic @@ -181,7 +193,7 @@ func (n *NIC) disableLocked() *tcpip.Error { return nil } - // TODO(b/147015577): Should Routes that are currently bound to n be + // TODO(gvisor.dev/issue/1491): Should Routes that are currently bound to n be // invalidated? Currently, Routes will continue to work when a NIC is enabled // again, and applications may not know that the underlying NIC was ever // disabled. @@ -1200,15 +1212,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // Are any packet sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] - // Check whether there are packet sockets listening for every protocol. - // If we received a packet with protocol EthernetProtocolAll, then the - // previous for loop will have handled it. - if protocol != header.EthernetProtocolAll { - packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) - } + // Add any other packet sockets that maybe listening for all protocols. + packetEPs = append(packetEPs, n.mu.packetEPs[header.EthernetProtocolAll]...) n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, pkt.Clone()) + p := pkt.Clone() + p.PktType = tcpip.PacketHost + ep.HandlePacket(n.id, local, protocol, p) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { @@ -1311,6 +1321,24 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } } +// DeliverOutboundPacket implements NetworkDispatcher.DeliverOutboundPacket. +func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + n.mu.RLock() + // We do not deliver to protocol specific packet endpoints as on Linux + // only ETH_P_ALL endpoints get outbound packets. + // Add any other packet sockets that maybe listening for all protocols. + packetEPs := n.mu.packetEPs[header.EthernetProtocolAll] + n.mu.RUnlock() + for _, ep := range packetEPs { + p := pkt.Clone() + p.PktType = tcpip.PacketOutgoing + // Add the link layer header as outgoing packets are intercepted + // before the link layer header is created. + n.linkEP.AddHeader(local, remote, protocol, p) + ep.HandlePacket(n.id, local, protocol, p) + } +} + func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { // TODO(b/143425874) Decrease the TTL field in forwarded packets. // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this @@ -1358,16 +1386,19 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // TransportHeader is nil only when pkt is an ICMP packet or was reassembled // from fragments. if pkt.TransportHeader == nil { - // TODO(gvisor.dev/issue/170): ICMP packets don't have their - // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a // full explanation. if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber { + // ICMP packets may be longer, but until icmp.Parse is implemented, here + // we parse it using the minimum size. transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize()) if !ok { n.stack.stats.MalformedRcvdPackets.Increment() return } pkt.TransportHeader = transHeader + pkt.Data.TrimFront(len(pkt.TransportHeader)) } else { // This is either a bad packet or was re-assembled from fragments. transProto.Parse(pkt) @@ -1521,6 +1552,27 @@ func (n *NIC) setNDPConfigs(c NDPConfigurations) { n.mu.Unlock() } +// NUDConfigs gets the NUD configurations for n. +func (n *NIC) NUDConfigs() (NUDConfigurations, *tcpip.Error) { + if n.neigh == nil { + return NUDConfigurations{}, tcpip.ErrNotSupported + } + return n.neigh.config(), nil +} + +// setNUDConfigs sets the NUD configurations for n. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (n *NIC) setNUDConfigs(c NUDConfigurations) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + c.resetInvalidFields() + n.neigh.setConfig(c) + return nil +} + // handleNDPRA handles an NDP Router Advertisement message that arrived on n. func (n *NIC) handleNDPRA(ip tcpip.Address, ra header.NDPRouterAdvert) { n.mu.Lock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 31f865260..a70792b50 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -84,6 +84,16 @@ func (e *testLinkEndpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { return tcpip.ErrNotSupported } +// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType. +func (*testLinkEndpoint) ARPHardwareType() header.ARPHardwareType { + panic("not implemented") +} + +// AddHeader implements stack.LinkEndpoint.AddHeader. +func (e *testLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) { + panic("not implemented") +} + var _ NetworkEndpoint = (*testIPv6Endpoint)(nil) // An IPv6 NetworkEndpoint that throws away outgoing packets. @@ -233,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go new file mode 100644 index 000000000..f848d50ad --- /dev/null +++ b/pkg/tcpip/stack/nud.go @@ -0,0 +1,466 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "math" + "sync" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +const ( + // defaultBaseReachableTime is the default base duration for computing the + // random reachable time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of a uniformly distributed random value between the minimum and + // maximum random factors, multiplied by the base reachable time. Using a + // random component eliminates the possibility that Neighbor Unreachability + // Detection messages will synchronize with each other. + // + // Default taken from REACHABLE_TIME of RFC 4861 section 10. + defaultBaseReachableTime = 30 * time.Second + + // minimumBaseReachableTime is the minimum base duration for computing the + // random reachable time. + // + // Minimum = 1ms + minimumBaseReachableTime = time.Millisecond + + // defaultMinRandomFactor is the default minimum value of the random factor + // used for computing reachable time. + // + // Default taken from MIN_RANDOM_FACTOR of RFC 4861 section 10. + defaultMinRandomFactor = 0.5 + + // defaultMaxRandomFactor is the default maximum value of the random factor + // used for computing reachable time. + // + // The default value depends on the value of MinRandomFactor. + // If MinRandomFactor is less than MAX_RANDOM_FACTOR of RFC 4861 section 10, + // the value from the RFC will be used; otherwise, the default is + // MinRandomFactor multiplied by three. + defaultMaxRandomFactor = 1.5 + + // defaultRetransmitTimer is the default amount of time to wait between + // sending reachability probes. + // + // Default taken from RETRANS_TIMER of RFC 4861 section 10. + defaultRetransmitTimer = time.Second + + // minimumRetransmitTimer is the minimum amount of time to wait between + // sending reachability probes. + // + // Note, RFC 4861 does not impose a minimum Retransmit Timer, but we do here + // to make sure the messages are not sent all at once. We also come to this + // value because in the RetransmitTimer field of a Router Advertisement, a + // value of 0 means unspecified, so the smallest valid value is 1. Note, the + // unit of the RetransmitTimer field in the Router Advertisement is + // milliseconds. + minimumRetransmitTimer = time.Millisecond + + // defaultDelayFirstProbeTime is the default duration to wait for a + // non-Neighbor-Discovery related protocol to reconfirm reachability after + // entering the DELAY state. After this time, a reachability probe will be + // sent and the entry will transition to the PROBE state. + // + // Default taken from DELAY_FIRST_PROBE_TIME of RFC 4861 section 10. + defaultDelayFirstProbeTime = 5 * time.Second + + // defaultMaxMulticastProbes is the default number of reachabililty probes + // to send before concluding negative reachability and deleting the neighbor + // entry from the INCOMPLETE state. + // + // Default taken from MAX_MULTICAST_SOLICIT of RFC 4861 section 10. + defaultMaxMulticastProbes = 3 + + // defaultMaxUnicastProbes is the default number of reachability probes to + // send before concluding retransmission from within the PROBE state should + // cease and the entry SHOULD be deleted. + // + // Default taken from MAX_UNICASE_SOLICIT of RFC 4861 section 10. + defaultMaxUnicastProbes = 3 + + // defaultMaxAnycastDelayTime is the default time in which the stack SHOULD + // delay sending a response for a random time between 0 and this time, if the + // target address is an anycast address. + // + // Default taken from MAX_ANYCAST_DELAY_TIME of RFC 4861 section 10. + defaultMaxAnycastDelayTime = time.Second + + // defaultMaxReachbilityConfirmations is the default amount of unsolicited + // reachability confirmation messages a node MAY send to all-node multicast + // address when it determines its link-layer address has changed. + // + // Default taken from MAX_NEIGHBOR_ADVERTISEMENT of RFC 4861 section 10. + defaultMaxReachbilityConfirmations = 3 + + // defaultUnreachableTime is the default duration for how long an entry will + // remain in the FAILED state before being removed from the neighbor cache. + // + // Note, there is no equivalent protocol constant defined in RFC 4861. It + // leaves the specifics of any garbage collection mechanism up to the + // implementation. + defaultUnreachableTime = 5 * time.Second +) + +// NUDDispatcher is the interface integrators of netstack must implement to +// receive and handle NUD related events. +type NUDDispatcher interface { + // OnNeighborAdded will be called when a new entry is added to a NIC's (with + // ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) + // neighbor table changes state and/or link address. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + + // OnNeighborRemoved will be called when an entry is removed from a NIC's + // (with ID nicID) neighbor table. + // + // This function is permitted to block indefinitely without interfering with + // the stack's operation. + // + // May be called concurrently. + OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) +} + +// ReachabilityConfirmationFlags describes the flags used within a reachability +// confirmation (e.g. ARP reply or Neighbor Advertisement for ARP or NDP, +// respectively). +type ReachabilityConfirmationFlags struct { + // Solicited indicates that the advertisement was sent in response to a + // reachability probe. + Solicited bool + + // Override indicates that the reachability confirmation should override an + // existing neighbor cache entry and update the cached link-layer address. + // When Override is not set the confirmation will not update a cached + // link-layer address, but will update an existing neighbor cache entry for + // which no link-layer address is known. + Override bool + + // IsRouter indicates that the sender is a router. + IsRouter bool +} + +// NUDHandler communicates external events to the Neighbor Unreachability +// Detection state machine, which is implemented per-interface. This is used by +// network endpoints to inform the Neighbor Cache of probes and confirmations. +type NUDHandler interface { + // HandleProbe processes an incoming neighbor probe (e.g. ARP request or + // Neighbor Solicitation for ARP or NDP, respectively). Validation of the + // probe needs to be performed before calling this function since the + // Neighbor Cache doesn't have access to view the NIC's assigned addresses. + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + + // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP + // reply or Neighbor Advertisement for ARP or NDP, respectively). + HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags ReachabilityConfirmationFlags) + + // HandleUpperLevelConfirmation processes an incoming upper-level protocol + // (e.g. TCP acknowledgements) reachability confirmation. + HandleUpperLevelConfirmation(addr tcpip.Address) +} + +// NUDConfigurations is the NUD configurations for the netstack. This is used +// by the neighbor cache to operate the NUD state machine on each device in the +// local network. +type NUDConfigurations struct { + // BaseReachableTime is the base duration for computing the random reachable + // time. + // + // Reachable time is the duration for which a neighbor is considered + // reachable after a positive reachability confirmation is received. It is a + // function of uniformly distributed random value between minRandomFactor and + // maxRandomFactor multiplied by baseReachableTime. Using a random component + // eliminates the possibility that Neighbor Unreachability Detection messages + // will synchronize with each other. + // + // After this time, a neighbor entry will transition from REACHABLE to STALE + // state. + // + // Must be greater than 0. + BaseReachableTime time.Duration + + // LearnBaseReachableTime enables learning BaseReachableTime during runtime + // from the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2240): Implement this NUD configuration option. + LearnBaseReachableTime bool + + // MinRandomFactor is the minimum value of the random factor used for + // computing reachable time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be greater than 0. + MinRandomFactor float32 + + // MaxRandomFactor is the maximum value of the random factor used for + // computing reachabile time. + // + // See BaseReachbleTime for more information on computing the reachable time. + // + // Must be great than or equal to MinRandomFactor. + MaxRandomFactor float32 + + // RetransmitTimer is the duration between retransmission of reachability + // probes in the PROBE state. + RetransmitTimer time.Duration + + // LearnRetransmitTimer enables learning RetransmitTimer during runtime from + // the neighbor discovery protocol, if supported. + // + // TODO(gvisor.dev/issue/2241): Implement this NUD configuration option. + LearnRetransmitTimer bool + + // DelayFirstProbeTime is the duration to wait for a non-Neighbor-Discovery + // related protocol to reconfirm reachability after entering the DELAY state. + // After this time, a reachability probe will be sent and the entry will + // transition to the PROBE state. + // + // Must be greater than 0. + DelayFirstProbeTime time.Duration + + // MaxMulticastProbes is the number of reachability probes to send before + // concluding negative reachability and deleting the neighbor entry from the + // INCOMPLETE state. + // + // Must be greater than 0. + MaxMulticastProbes uint32 + + // MaxUnicastProbes is the number of reachability probes to send before + // concluding retransmission from within the PROBE state should cease and + // entry SHOULD be deleted. + // + // Must be greater than 0. + MaxUnicastProbes uint32 + + // MaxAnycastDelayTime is the time in which the stack SHOULD delay sending a + // response for a random time between 0 and this time, if the target address + // is an anycast address. + // + // TODO(gvisor.dev/issue/2242): Use this option when sending solicited + // neighbor confirmations to anycast addresses and proxying neighbor + // confirmations. + MaxAnycastDelayTime time.Duration + + // MaxReachabilityConfirmations is the number of unsolicited reachability + // confirmation messages a node MAY send to all-node multicast address when + // it determines its link-layer address has changed. + // + // TODO(gvisor.dev/issue/2246): Discuss if implementation of this NUD + // configuration option is necessary. + MaxReachabilityConfirmations uint32 + + // UnreachableTime describes how long an entry will remain in the FAILED + // state before being removed from the neighbor cache. + UnreachableTime time.Duration +} + +// DefaultNUDConfigurations returns a NUDConfigurations populated with default +// values defined by RFC 4861 section 10. +func DefaultNUDConfigurations() NUDConfigurations { + return NUDConfigurations{ + BaseReachableTime: defaultBaseReachableTime, + LearnBaseReachableTime: true, + MinRandomFactor: defaultMinRandomFactor, + MaxRandomFactor: defaultMaxRandomFactor, + RetransmitTimer: defaultRetransmitTimer, + LearnRetransmitTimer: true, + DelayFirstProbeTime: defaultDelayFirstProbeTime, + MaxMulticastProbes: defaultMaxMulticastProbes, + MaxUnicastProbes: defaultMaxUnicastProbes, + MaxAnycastDelayTime: defaultMaxAnycastDelayTime, + MaxReachabilityConfirmations: defaultMaxReachbilityConfirmations, + UnreachableTime: defaultUnreachableTime, + } +} + +// resetInvalidFields modifies an invalid NDPConfigurations with valid values. +// If invalid values are present in c, the corresponding default values will be +// used instead. This is needed to check, and conditionally fix, user-specified +// NUDConfigurations. +func (c *NUDConfigurations) resetInvalidFields() { + if c.BaseReachableTime < minimumBaseReachableTime { + c.BaseReachableTime = defaultBaseReachableTime + } + if c.MinRandomFactor <= 0 { + c.MinRandomFactor = defaultMinRandomFactor + } + if c.MaxRandomFactor < c.MinRandomFactor { + c.MaxRandomFactor = calcMaxRandomFactor(c.MinRandomFactor) + } + if c.RetransmitTimer < minimumRetransmitTimer { + c.RetransmitTimer = defaultRetransmitTimer + } + if c.DelayFirstProbeTime == 0 { + c.DelayFirstProbeTime = defaultDelayFirstProbeTime + } + if c.MaxMulticastProbes == 0 { + c.MaxMulticastProbes = defaultMaxMulticastProbes + } + if c.MaxUnicastProbes == 0 { + c.MaxUnicastProbes = defaultMaxUnicastProbes + } + if c.UnreachableTime == 0 { + c.UnreachableTime = defaultUnreachableTime + } +} + +// calcMaxRandomFactor calculates the maximum value of the random factor used +// for computing reachable time. This function is necessary for when the +// default specified in RFC 4861 section 10 is less than the current +// MinRandomFactor. +// +// Assumes minRandomFactor is positive since validation of the minimum value +// should come before the validation of the maximum. +func calcMaxRandomFactor(minRandomFactor float32) float32 { + if minRandomFactor > defaultMaxRandomFactor { + return minRandomFactor * 3 + } + return defaultMaxRandomFactor +} + +// A Rand is a source of random numbers. +type Rand interface { + // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). + Float32() float32 +} + +// NUDState stores states needed for calculating reachable time. +type NUDState struct { + rng Rand + + // mu protects the fields below. + // + // It is necessary for NUDState to handle its own locking since neighbor + // entries may access the NUD state from within the goroutine spawned by + // time.AfterFunc(). This goroutine may run concurrently with the main + // process for controlling the neighbor cache and would otherwise introduce + // race conditions if NUDState was not locked properly. + mu sync.RWMutex + + config NUDConfigurations + + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration + + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 +} + +// NewNUDState returns new NUDState using c as configuration and the specified +// random number generator for use in recomputing ReachableTime. +func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { + s := &NUDState{ + rng: rng, + } + s.config = c + return s +} + +// Config returns the NUD configuration. +func (s *NUDState) Config() NUDConfigurations { + s.mu.RLock() + defer s.mu.RUnlock() + return s.config +} + +// SetConfig replaces the existing NUD configurations with c. +func (s *NUDState) SetConfig(c NUDConfigurations) { + s.mu.Lock() + defer s.mu.Unlock() + s.config = c +} + +// ReachableTime returns the duration to wait for a REACHABLE entry to +// transition into STALE after inactivity. This value is recalculated for new +// values of BaseReachableTime, MinRandomFactor, and MaxRandomFactor using the +// algorithm defined in RFC 4861 section 6.3.2. +func (s *NUDState) ReachableTime() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + + if time.Now().After(s.expiration) || + s.config.BaseReachableTime != s.prevBaseReachableTime || + s.config.MinRandomFactor != s.prevMinRandomFactor || + s.config.MaxRandomFactor != s.prevMaxRandomFactor { + return s.recomputeReachableTimeLocked() + } + return s.reachableTime +} + +// recomputeReachableTimeLocked forces a recalculation of ReachableTime using +// the algorithm defined in RFC 4861 section 6.3.2. +// +// This SHOULD automatically be invoked during certain situations, as per +// RFC 4861 section 6.3.4: +// +// If the received Reachable Time value is non-zero, the host SHOULD set its +// BaseReachableTime variable to the received value. If the new value +// differs from the previous value, the host SHOULD re-compute a new random +// ReachableTime value. ReachableTime is computed as a uniformly +// distributed random value between MIN_RANDOM_FACTOR and MAX_RANDOM_FACTOR +// times the BaseReachableTime. Using a random component eliminates the +// possibility that Neighbor Unreachability Detection messages will +// synchronize with each other. +// +// In most cases, the advertised Reachable Time value will be the same in +// consecutive Router Advertisements, and a host's BaseReachableTime rarely +// changes. In such cases, an implementation SHOULD ensure that a new +// random value gets re-computed at least once every few hours. +// +// s.mu MUST be locked for writing. +func (s *NUDState) recomputeReachableTimeLocked() time.Duration { + s.prevBaseReachableTime = s.config.BaseReachableTime + s.prevMinRandomFactor = s.config.MinRandomFactor + s.prevMaxRandomFactor = s.config.MaxRandomFactor + + randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + + // Check for overflow, given that minRandomFactor and maxRandomFactor are + // guaranteed to be positive numbers. + if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { + s.reachableTime = time.Duration(math.MaxInt64) + } else if randomFactor == 1 { + // Avoid loss of precision when a large base reachable time is used. + s.reachableTime = s.config.BaseReachableTime + } else { + reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) + s.reachableTime = time.Duration(reachableTime) + } + + s.expiration = time.Now().Add(2 * time.Hour) + return s.reachableTime +} diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go new file mode 100644 index 000000000..2494ee610 --- /dev/null +++ b/pkg/tcpip/stack/nud_test.go @@ -0,0 +1,795 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +const ( + defaultBaseReachableTime = 30 * time.Second + minimumBaseReachableTime = time.Millisecond + defaultMinRandomFactor = 0.5 + defaultMaxRandomFactor = 1.5 + defaultRetransmitTimer = time.Second + minimumRetransmitTimer = time.Millisecond + defaultDelayFirstProbeTime = 5 * time.Second + defaultMaxMulticastProbes = 3 + defaultMaxUnicastProbes = 3 + defaultMaxAnycastDelayTime = time.Second + defaultMaxReachbilityConfirmations = 3 + defaultUnreachableTime = 5 * time.Second + + defaultFakeRandomNum = 0.5 +) + +// fakeRand is a deterministic random number generator. +type fakeRand struct { + num float32 +} + +var _ stack.Rand = (*fakeRand)(nil) + +func (f *fakeRand) Float32() float32 { + return f.num +} + +// TestSetNUDConfigurationFailsForBadNICID tests to make sure we get an error if +// we attempt to update NUD configurations using an invalid NICID. +func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + }) + + // No NIC with ID 1 yet. + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(1, config); err != tcpip.ErrUnknownNICID { + t.Fatalf("got s.SetNDPConfigurations(1, %+v) = %v, want = %s", config, err, tcpip.ErrUnknownNICID) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to retrieve NUD configurations when the +// stack doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if _, err := s.NUDConfigurations(nicID); err != tcpip.ErrNotSupported { + t.Fatalf("got s.NDPConfigurations(%d) = %v, want = %s", nicID, err, tcpip.ErrNotSupported) + } +} + +// TestNUDConfigurationFailsForNotSupported tests to make sure we get a +// NotSupported error if we attempt to set NUD configurations when the stack +// doesn't support NUD. +// +// The stack will report to not support NUD if a neighbor cache for a given NIC +// is not allocated. The networking stack will only allocate neighbor caches if +// a protocol providing link address resolution is specified (e.g. ARP, IPv6). +func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + config := stack.NUDConfigurations{} + if err := s.SetNUDConfigurations(nicID, config); err != tcpip.ErrNotSupported { + t.Fatalf("got s.SetNDPConfigurations(%d, %+v) = %v, want = %s", nicID, config, err, tcpip.ErrNotSupported) + } +} + +// TestDefaultNUDConfigurationIsValid verifies that calling +// resetInvalidFields() on the result of DefaultNUDConfigurations() does not +// change anything. DefaultNUDConfigurations() should return a valid +// NUDConfigurations. +func TestDefaultNUDConfigurations(t *testing.T) { + const nicID = 1 + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The networking + // stack will only allocate neighbor caches if a protocol providing link + // address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: stack.DefaultNUDConfigurations(), + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + c, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got, want := c, stack.DefaultNUDConfigurations(); got != want { + t.Errorf("got stack.NUDConfigurations(%d) = %+v, want = %+v", nicID, got, want) + } +} + +func TestNUDConfigurationsBaseReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + baseReachableTime: 0, + want: defaultBaseReachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + baseReachableTime: time.Millisecond, + want: time.Millisecond, + }, + { + name: "MoreThanDefaultBaseReachableTime", + baseReachableTime: 2 * defaultBaseReachableTime, + want: 2 * defaultBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = test.baseReachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.BaseReachableTime; got != test.want { + t.Errorf("got BaseReachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMinRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: -1, + want: defaultMinRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: 0, + want: defaultMinRandomFactor, + }, + // Valid cases + { + name: "MoreThanZero", + minRandomFactor: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MinRandomFactor; got != test.want { + t.Errorf("got MinRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { + tests := []struct { + name string + minRandomFactor float32 + maxRandomFactor float32 + want float32 + }{ + // Invalid cases + { + name: "LessThanZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: -1, + want: defaultMaxRandomFactor, + }, + { + name: "EqualToZero", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: 0, + want: defaultMaxRandomFactor, + }, + { + name: "LessThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 0.99, + want: defaultMaxRandomFactor, + }, + { + name: "MoreThanMinRandomFactorWhenMinRandomFactorIsLargerThanMaxRandomFactorDefault", + minRandomFactor: defaultMaxRandomFactor * 2, + maxRandomFactor: defaultMaxRandomFactor, + want: defaultMaxRandomFactor * 6, + }, + // Valid cases + { + name: "EqualToMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor, + want: defaultMinRandomFactor, + }, + { + name: "MoreThanMinRandomFactor", + minRandomFactor: defaultMinRandomFactor, + maxRandomFactor: defaultMinRandomFactor * 1.1, + want: defaultMinRandomFactor * 1.1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxRandomFactor; got != test.want { + t.Errorf("got MaxRandomFactor = %f, want = %f", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsRetransmitTimer(t *testing.T) { + tests := []struct { + name string + retransmitTimer time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + retransmitTimer: 0, + want: defaultRetransmitTimer, + }, + { + name: "LessThanMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer - time.Nanosecond, + want: defaultRetransmitTimer, + }, + // Valid cases + { + name: "EqualToMinimumRetransmitTimer", + retransmitTimer: minimumRetransmitTimer, + want: minimumBaseReachableTime, + }, + { + name: "LargetThanMinimumRetransmitTimer", + retransmitTimer: 2 * minimumBaseReachableTime, + want: 2 * minimumBaseReachableTime, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.RetransmitTimer = test.retransmitTimer + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.RetransmitTimer; got != test.want { + t.Errorf("got RetransmitTimer = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { + tests := []struct { + name string + delayFirstProbeTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + delayFirstProbeTime: 0, + want: defaultDelayFirstProbeTime, + }, + // Valid cases + { + name: "MoreThanZero", + delayFirstProbeTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.DelayFirstProbeTime = test.delayFirstProbeTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.DelayFirstProbeTime; got != test.want { + t.Errorf("got DelayFirstProbeTime = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { + tests := []struct { + name string + maxMulticastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxMulticastProbes: 0, + want: defaultMaxMulticastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxMulticastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxMulticastProbes = test.maxMulticastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxMulticastProbes; got != test.want { + t.Errorf("got MaxMulticastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { + tests := []struct { + name string + maxUnicastProbes uint32 + want uint32 + }{ + // Invalid cases + { + name: "EqualToZero", + maxUnicastProbes: 0, + want: defaultMaxUnicastProbes, + }, + // Valid cases + { + name: "MoreThanZero", + maxUnicastProbes: 1, + want: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.MaxUnicastProbes = test.maxUnicastProbes + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.MaxUnicastProbes; got != test.want { + t.Errorf("got MaxUnicastProbes = %q, want = %q", got, test.want) + } + }) + } +} + +func TestNUDConfigurationsUnreachableTime(t *testing.T) { + tests := []struct { + name string + unreachableTime time.Duration + want time.Duration + }{ + // Invalid cases + { + name: "EqualToZero", + unreachableTime: 0, + want: defaultUnreachableTime, + }, + // Valid cases + { + name: "MoreThanZero", + unreachableTime: time.Millisecond, + want: time.Millisecond, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const nicID = 1 + + c := stack.DefaultNUDConfigurations() + c.UnreachableTime = test.unreachableTime + + e := channel.New(0, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + // A neighbor cache is required to store NUDConfigurations. The + // networking stack will only allocate neighbor caches if a protocol + // providing link address resolution is specified (e.g. ARP or IPv6). + NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + NUDConfigs: c, + }) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + sc, err := s.NUDConfigurations(nicID) + if err != nil { + t.Fatalf("got stack.NUDConfigurations(%d) = %s", nicID, err) + } + if got := sc.UnreachableTime; got != test.want { + t.Errorf("got UnreachableTime = %q, want = %q", got, test.want) + } + }) + } +} + +// TestNUDStateReachableTime verifies the correctness of the ReachableTime +// computation. +func TestNUDStateReachableTime(t *testing.T) { + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "AllZeros", + baseReachableTime: 0, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMaxRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 0, + want: 0, + }, + { + name: "ZeroMinRandomFactor", + baseReachableTime: time.Second, + minRandomFactor: 0, + maxRandomFactor: 1, + want: time.Duration(defaultFakeRandomNum * float32(time.Second)), + }, + { + name: "FractionalRandomFactor", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 0.001, + maxRandomFactor: 0.002, + want: time.Duration((0.001 + (0.001 * defaultFakeRandomNum)) * float32(math.MaxInt64)), + }, + { + name: "MinAndMaxRandomFactorsEqual", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Second, + }, + { + name: "MinAndMaxRandomFactorsDifferent", + baseReachableTime: time.Second, + minRandomFactor: 1, + maxRandomFactor: 2, + want: time.Duration((1.0 + defaultFakeRandomNum) * float32(time.Second)), + }, + { + name: "MaxInt64", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1, + maxRandomFactor: 1, + want: time.Duration(math.MaxInt64), + }, + { + name: "Overflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 1.5, + maxRandomFactor: 1.5, + want: time.Duration(math.MaxInt64), + }, + { + name: "DoubleOverflow", + baseReachableTime: time.Duration(math.MaxInt64), + minRandomFactor: 2.5, + maxRandomFactor: 2.5, + want: time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.NUDConfigurations{ + BaseReachableTime: test.baseReachableTime, + MinRandomFactor: test.minRandomFactor, + MaxRandomFactor: test.maxRandomFactor, + } + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} + +// TestNUDStateRecomputeReachableTime exercises the ReachableTime function +// twice to verify recomputation of reachable time when the min random factor, +// max random factor, or base reachable time changes. +func TestNUDStateRecomputeReachableTime(t *testing.T) { + const defaultBase = time.Second + const defaultMin = 2.0 * defaultMaxRandomFactor + const defaultMax = 3.0 * defaultMaxRandomFactor + + tests := []struct { + name string + baseReachableTime time.Duration + minRandomFactor float32 + maxRandomFactor float32 + want time.Duration + }{ + { + name: "BaseReachableTime", + baseReachableTime: 2 * defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMax, + want: time.Duration((defaultMin + (defaultMax-defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + { + name: "MinRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMax, + maxRandomFactor: defaultMax, + want: time.Duration(defaultMax * float32(defaultBase)), + }, + { + name: "MaxRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: defaultMin, + maxRandomFactor: defaultMin, + want: time.Duration(defaultMin * float32(defaultBase)), + }, + { + name: "BothRandomFactor", + baseReachableTime: defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(defaultBase)), + }, + { + name: "BaseReachableTimeAndBothRandomFactors", + baseReachableTime: 2 * defaultBase, + minRandomFactor: 2 * defaultMin, + maxRandomFactor: 2 * defaultMax, + want: time.Duration((2*defaultMin + (2*defaultMax-2*defaultMin)*defaultFakeRandomNum) * float32(2*defaultBase)), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := stack.DefaultNUDConfigurations() + c.BaseReachableTime = defaultBase + c.MinRandomFactor = defaultMin + c.MaxRandomFactor = defaultMax + + // A fake random number generator is used to ensure deterministic + // results. + rng := fakeRand{ + num: defaultFakeRandomNum, + } + s := stack.NewNUDState(c, &rng) + old := s.ReachableTime() + + if got, want := s.ReachableTime(), old; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Check for recomputation when changing the min random factor, the max + // random factor, the base reachability time, or any permutation of those + // three options. + c.BaseReachableTime = test.baseReachableTime + c.MinRandomFactor = test.minRandomFactor + c.MaxRandomFactor = test.maxRandomFactor + s.SetConfig(c) + + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + + // Verify that ReachableTime isn't recomputed when none of the + // configuration options change. The random factor is changed so that if + // a recompution were to occur, ReachableTime would change. + rng.num = defaultFakeRandomNum / 2.0 + if got, want := s.ReachableTime(), test.want; got != want { + t.Errorf("got ReachableTime = %q, want = %q", got, want) + } + }) + } +} diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 1b5da6017..5d6865e35 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -14,6 +14,7 @@ package stack import ( + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" ) @@ -24,7 +25,7 @@ import ( // multiple endpoints. Clone() should be called in such cases so that // modifications to the Data field do not affect other copies. type PacketBuffer struct { - _ noCopy + _ sync.NoCopy // PacketBufferEntry is used to build an intrusive list of // PacketBuffers. @@ -78,6 +79,10 @@ type PacketBuffer struct { // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. NatDone bool + + // PktType indicates the SockAddrLink.PacketType of the packet as defined in + // https://www.man7.org/linux/man-pages/man7/packet.7.html. + PktType tcpip.PacketType } // Clone makes a copy of pk. It clones the Data field, which creates a new @@ -102,14 +107,3 @@ func (pk *PacketBuffer) Clone() *PacketBuffer { NatDone: pk.NatDone, } } - -// noCopy may be embedded into structs which must not be copied -// after the first use. -// -// See https://golang.org/issues/8005#issuecomment-190753527 -// for details. -type noCopy struct{} - -// Lock is a no-op used by -copylocks checker from `go vet`. -func (*noCopy) Lock() {} -func (*noCopy) Unlock() {} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 5cbc946b6..8604c4259 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/waiter" ) @@ -51,8 +52,11 @@ type TransportEndpointID struct { type ControlType int // The following are the allowed values for ControlType values. +// TODO(http://gvisor.dev/issue/3210): Support time exceeded messages. const ( - ControlPacketTooBig ControlType = iota + ControlNetworkUnreachable ControlType = iota + ControlNoRoute + ControlPacketTooBig ControlPortUnreachable ControlUnknown ) @@ -329,8 +333,7 @@ type NetworkProtocol interface { } // NetworkDispatcher contains the methods used by the network stack to deliver -// packets to the appropriate network endpoint after it has been handled by -// the data link layer. +// inbound/outbound packets to the appropriate network/packet(if any) endpoints. type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. @@ -341,6 +344,16 @@ type NetworkDispatcher interface { // // DeliverNetworkPacket takes ownership of pkt. DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) + + // DeliverOutboundPacket is called by link layer when a packet is being + // sent out. + // + // pkt.LinkHeader may or may not be set before calling + // DeliverOutboundPacket. Some packets do not have link headers (e.g. + // packets sent via loopback), and won't have the field set. + // + // DeliverOutboundPacket takes ownership of pkt. + DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // LinkEndpointCapabilities is the type associated with the capabilities @@ -436,6 +449,15 @@ type LinkEndpoint interface { // Wait will not block if the endpoint hasn't started any goroutines // yet, even if it might later. Wait() + + // ARPHardwareType returns the ARPHRD_TYPE of the link endpoint. + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/include/uapi/linux/if_arp.h#L30 + ARPHardwareType() header.ARPHardwareType + + // AddHeader adds a link layer header to pkt if required. + AddHeader(local, remote tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -456,12 +478,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. - // The request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts + // the request on the local network if remoteLinkAddr is the zero value. The + // request is sent on linkEP with localAddr as the source. // // A valid response will cause the discovery protocol's network // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index d65f8049e..91e0110f1 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -48,6 +48,10 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping + + // directedBroadcast indicates whether this route is sending a directed + // broadcast packet. + directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -275,6 +279,12 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } +// IsBroadcast returns true if the route is to send a broadcast packet. +func (r *Route) IsBroadcast() bool { + // Only IPv4 has a notion of broadcast. + return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast +} + // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 51abe32a7..3f07e4159 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -425,6 +425,7 @@ type Stack struct { handleLocal bool // tables are the iptables packet filtering and manipulation rules. + // TODO(gvisor.dev/issue/170): S/R this field. tables *IPTables // resumableEndpoints is a list of endpoints that need to be resumed if the @@ -444,6 +445,9 @@ type Stack struct { // ndpConfigs is the default NDP configurations used by interfaces. ndpConfigs NDPConfigurations + // nudConfigs is the default NUD configurations used by interfaces. + nudConfigs NUDConfigurations + // autoGenIPv6LinkLocal determines whether or not the stack will attempt // to auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. @@ -453,6 +457,10 @@ type Stack struct { // integrator NDP related events. ndpDisp NDPDispatcher + // nudDisp is the NUD event dispatcher that is used to send the netstack + // integrator NUD related events. + nudDisp NUDDispatcher + // uniqueIDGenerator is a generator of unique identifiers. uniqueIDGenerator UniqueID @@ -471,6 +479,14 @@ type Stack struct { // randomGenerator is an injectable pseudo random generator that can be // used when a random number is required. randomGenerator *mathrand.Rand + + // sendBufferSize holds the min/default/max send buffer sizes for + // endpoints other than TCP. + sendBufferSize SendBufferSizeOption + + // receiveBufferSize holds the min/default/max receive buffer sizes for + // endpoints other than TCP. + receiveBufferSize ReceiveBufferSizeOption } // UniqueID is an abstract generator of unique identifiers. @@ -509,6 +525,9 @@ type Options struct { // before assigning an address to a NIC. NDPConfigs NDPConfigurations + // NUDConfigs is the default NUD configurations used by interfaces. + NUDConfigs NUDConfigurations + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to // auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. @@ -527,6 +546,10 @@ type Options struct { // receive NDP related events. NDPDisp NDPDispatcher + // NUDDisp is the NUD event dispatcher that an integrator can provide to + // receive NUD related events. + NUDDisp NUDDispatcher + // RawFactory produces raw endpoints. Raw endpoints are enabled only if // this is non-nil. RawFactory RawFactory @@ -661,6 +684,8 @@ func New(opts Options) *Stack { // Make sure opts.NDPConfigs contains valid values only. opts.NDPConfigs.validate() + opts.NUDConfigs.resetInvalidFields() + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), @@ -676,13 +701,25 @@ func New(opts Options) *Stack { icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, + nudConfigs: opts.NUDConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, + nudDisp: opts.NUDDisp, opaqueIIDOpts: opts.OpaqueIIDOpts, tempIIDSeed: opts.TempIIDSeed, forwarder: newForwardQueue(), randomGenerator: mathrand.New(randSrc), + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, + receiveBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultBufferSize, + Max: DefaultMaxBufferSize, + }, } // Add specified network protocols. @@ -709,6 +746,11 @@ func New(opts Options) *Stack { return s } +// newJob returns a tcpip.Job using the Stack clock. +func (s *Stack) newJob(l sync.Locker, f func()) *tcpip.Job { + return tcpip.NewJob(s.clock, l, f) +} + // UniqueID returns a unique identifier. func (s *Stack) UniqueID() uint64 { return s.uniqueIDGenerator.UniqueID() @@ -782,9 +824,10 @@ func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h f } } -// NowNanoseconds implements tcpip.Clock.NowNanoseconds. -func (s *Stack) NowNanoseconds() int64 { - return s.clock.NowNanoseconds() +// Clock returns the Stack's clock for retrieving the current time and +// scheduling work. +func (s *Stack) Clock() tcpip.Clock { + return s.clock } // Stats returns a mutable copy of the current stats. @@ -1076,6 +1119,11 @@ type NICInfo struct { // Context is user-supplied data optionally supplied in CreateNICWithOptions. // See type NICOptions for more details. Context NICContext + + // ARPHardwareType holds the ARP Hardware type of the NIC. This is the + // value sent in haType field of an ARP Request sent by this NIC and the + // value expected in the haType field of an ARP response. + ARPHardwareType header.ARPHardwareType } // HasNIC returns true if the NICID is defined in the stack. @@ -1107,6 +1155,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { MTU: nic.linkEP.MTU(), Stats: nic.stats, Context: nic.context, + ARPHardwareType: nic.linkEP.ARPHardwareType(), } } return nics @@ -1253,9 +1302,9 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isBroadcast := remoteAddr == header.IPv4Broadcast + isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.enabled() { if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil { @@ -1276,9 +1325,16 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - if needRoute { - r.NextHop = route.Gateway + r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) + + if len(route.Gateway) > 0 { + if needRoute { + r.NextHop = route.Gateway + } + } else if r.directedBroadcast { + r.RemoteLinkAddress = header.EthernetBroadcastAddress } + return r, nil } } @@ -1408,6 +1464,12 @@ func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.N return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } +// CheckRegisterTransportEndpoint checks if an endpoint can be registered with +// the stack transport dispatcher. +func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice) +} + // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { @@ -1825,10 +1887,38 @@ func (s *Stack) SetNDPConfigurations(id tcpip.NICID, c NDPConfigurations) *tcpip } nic.setNDPConfigs(c) - return nil } +// NUDConfigurations gets the per-interface NUD configurations. +func (s *Stack) NUDConfigurations(id tcpip.NICID) (NUDConfigurations, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return NUDConfigurations{}, tcpip.ErrUnknownNICID + } + + return nic.NUDConfigs() +} + +// SetNUDConfigurations sets the per-interface NUD configurations. +// +// Note, if c contains invalid NUD configuration values, it will be fixed to +// use default values for the erroneous values. +func (s *Stack) SetNUDConfigurations(id tcpip.NICID, c NUDConfigurations) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[id] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.setNUDConfigs(c) +} + // HandleNDPRA provides a NIC with ID id a validated NDP Router Advertisement // message that it needs to handle. func (s *Stack) HandleNDPRA(id tcpip.NICID, ip tcpip.Address, ra header.NDPRouterAdvert) *tcpip.Error { diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go new file mode 100644 index 000000000..0b093e6c5 --- /dev/null +++ b/pkg/tcpip/stack/stack_options.go @@ -0,0 +1,106 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import "gvisor.dev/gvisor/pkg/tcpip" + +const ( + // MinBufferSize is the smallest size of a receive or send buffer. + MinBufferSize = 4 << 10 // 4 KiB + + // DefaultBufferSize is the default size of the send/recv buffer for a + // transport endpoint. + DefaultBufferSize = 212 << 10 // 212 KiB + + // DefaultMaxBufferSize is the default maximum permitted size of a + // send/receive buffer. + DefaultMaxBufferSize = 4 << 20 // 4 MiB +) + +// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to +// get/set the default, min and max receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +// SetOption allows setting stack wide options. +func (s *Stack) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SendBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.sendBufferSize = v + s.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + // Make sure we don't allow lowering the buffer below minimum + // required for stack to work. + if v.Min < MinBufferSize { + return tcpip.ErrInvalidOptionValue + } + + if v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + + s.mu.Lock() + s.receiveBufferSize = v + s.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option allows retrieving stack wide options. +func (s *Stack) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SendBufferSizeOption: + s.mu.RLock() + *v = s.sendBufferSize + s.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + s.mu.RLock() + *v = s.receiveBufferSize + s.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 5aacbf53e..f22062889 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -27,6 +27,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -867,9 +868,9 @@ func TestRouteWithDownNIC(t *testing.T) { // Writes with Routes that use NIC1 after being brought up should // succeed. // - // TODO(b/147015577): Should we instead completely invalidate all - // Routes that were bound to a NIC that was brought down at some - // point? + // TODO(gvisor.dev/issue/1491): Should we instead completely + // invalidate all Routes that were bound to a NIC that was brought + // down at some point? if err := upFn(s, nicID1); err != nil { t.Fatalf("test.upFn(_, %d): %s", nicID1, err) } @@ -3338,3 +3339,305 @@ func TestDoDADWhenNICEnabled(t *testing.T) { t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix) } } + +func TestStackReceiveBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + rs stack.ReceiveBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.ReceiveBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.ReceiveBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.ReceiveBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.rs); err != tc.err { + t.Fatalf("s.SetOption(%#v) = %v, want: %v", tc.rs, err, tc.err) + } + var rs stack.ReceiveBufferSizeOption + if tc.err == nil { + if err := s.Option(&rs); err != nil { + t.Fatalf("s.Option(%#v) = %v, want: nil", rs, err) + } + if got, want := rs, tc.rs; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestStackSendBufferSizeOption(t *testing.T) { + const sMin = stack.MinBufferSize + testCases := []struct { + name string + ss stack.SendBufferSizeOption + err *tcpip.Error + }{ + // Invalid configurations. + {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue}, + {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue}, + + // Valid Configurations + {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil}, + {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil}, + {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil}, + {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := stack.New(stack.Options{}) + defer s.Close() + if err := s.SetOption(tc.ss); err != tc.err { + t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err) + } + var ss stack.SendBufferSizeOption + if tc.err == nil { + if err := s.Option(&ss); err != nil { + t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err) + } + if got, want := ss, tc.ss; got != want { + t.Fatalf("s.Option(..) returned unexpected value got: %#v, want: %#v", got, want) + } + } + }) + } +} + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID1 = 1 + ) + + defaultAddr := tcpip.AddressWithPrefix{ + Address: header.IPv4Any, + PrefixLen: 0, + } + defaultSubnet := defaultAddr.Subnet() + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + expectedRoute stack.Route + }{ + // Broadcast to a locally attached subnet populates the broadcast MAC. + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: ipv4SubnetBcast, + RemoteLinkAddress: header.EthernetBroadcastAddress, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /31 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix31.Address, + RemoteAddress: ipv4Subnet31Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a locally attached /32 subnet does not populate the + // broadcast MAC. + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4AddrPrefix32.Address, + RemoteAddress: ipv4Subnet32Bcast, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv6Addr.Address, + RemoteAddress: ipv6SubnetBcast, + NetProto: header.IPv6ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to a remote subnet in the route table is send to the next-hop + // gateway. + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + // Broadcast to an unknown subnet follows the default route. Note that this + // is essentially just routing an unknown destination IP, because w/o any + // subnet prefix information a subnet broadcast address is just a normal IP. + { + name: "IPv4 Broadcast to unknown subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: defaultSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + expectedRoute: stack.Route{ + LocalAddress: ipv4Addr.Address, + RemoteAddress: remNetSubnetBcast, + NextHop: ipv4Gateway, + NetProto: header.IPv4ProtocolNumber, + Loop: stack.PacketOut, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + ep := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + if r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, test.remoteAddr, netProto, false /* multicastLoop */); err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, test.remoteAddr, netProto, err) + } else if diff := cmp.Diff(r, test.expectedRoute, cmpopts.IgnoreUnexported(r)); diff != "" { + t.Errorf("route mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 118b449d5..b902c6ca9 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -221,6 +221,18 @@ func (epsByNIC *endpointsByNIC) registerEndpoint(d *transportDemuxer, netProto t return multiPortEp.singleRegisterEndpoint(t, flags) } +func (epsByNIC *endpointsByNIC) checkEndpoint(d *transportDemuxer, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNIC.mu.RLock() + defer epsByNIC.mu.RUnlock() + + multiPortEp, ok := epsByNIC.endpoints[bindToDevice] + if !ok { + return nil + } + + return multiPortEp.singleCheckEndpoint(flags) +} + // unregisterEndpoint returns true if endpointsByNIC has to be unregistered. func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint, flags ports.Flags) bool { epsByNIC.mu.Lock() @@ -289,6 +301,17 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum return nil } +// checkEndpoint checks if an endpoint can be registered with the dispatcher. +func (d *transportDemuxer) checkEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + for _, n := range netProtos { + if err := d.singleCheckEndpoint(n, protocol, id, flags, bindToDevice); err != nil { + return err + } + } + + return nil +} + // multiPortEndpoint is a container for TransportEndpoints which are bound to // the same pair of address and port. endpointsArr always has at least one // element. @@ -380,7 +403,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p ep.mu.Lock() defer ep.mu.Unlock() - bits := flags.Bits() + bits := flags.Bits() & ports.MultiBindFlagMask if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. @@ -395,6 +418,22 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p return nil } +func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) *tcpip.Error { + ep.mu.RLock() + defer ep.mu.RUnlock() + + bits := flags.Bits() & ports.MultiBindFlagMask + + if len(ep.endpoints) != 0 { + // If it was previously bound, we need to check if we can bind again. + if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + return tcpip.ErrPortInUse + } + } + + return nil +} + // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports.Flags) bool { ep.mu.Lock() @@ -406,7 +445,7 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint, flags ports ep.endpoints[len(ep.endpoints)-1] = nil ep.endpoints = ep.endpoints[:len(ep.endpoints)-1] - ep.flags.DropRef(flags.Bits()) + ep.flags.DropRef(flags.Bits() & ports.MultiBindFlagMask) break } } @@ -439,6 +478,28 @@ func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocol return epsByNIC.registerEndpoint(d, netProto, protocol, ep, flags, bindToDevice) } +func (d *transportDemuxer) singleCheckEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error { + if id.RemotePort != 0 { + // SO_REUSEPORT only applies to bound/listening endpoints. + flags.LoadBalanced = false + } + + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return tcpip.ErrUnknownProtocol + } + + eps.mu.RLock() + defer eps.mu.RUnlock() + + epsByNIC, ok := eps.endpoints[id] + if !ok { + return nil + } + + return epsByNIC.checkEndpoint(d, netProto, protocol, flags, bindToDevice) +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 3ad130b23..a634b9b60 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -43,6 +43,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// Using header.IPv4AddressSize would cause an import cycle. +const ipv4AddressSize = 4 + // Error represents an error in the netstack error space. Using a special type // ensures that errors outside of this space are not accidentally introduced. // @@ -192,7 +195,7 @@ func (e ErrSaveRejection) Error() string { return "save rejected due to unsupported networking state: " + e.Err.Error() } -// A Clock provides the current time. +// A Clock provides the current time and schedules work for execution. // // Times returned by a Clock should always be used for application-visible // time. Only monotonic times should be used for netstack internal timekeeping. @@ -203,6 +206,31 @@ type Clock interface { // NowMonotonic returns a monotonic time value. NowMonotonic() int64 + + // AfterFunc waits for the duration to elapse and then calls f in its own + // goroutine. It returns a Timer that can be used to cancel the call using + // its Stop method. + AfterFunc(d time.Duration, f func()) Timer +} + +// Timer represents a single event. A Timer must be created with +// Clock.AfterFunc. +type Timer interface { + // Stop prevents the Timer from firing. It returns true if the call stops the + // timer, false if the timer has already expired or been stopped. + // + // If Stop returns false, then the timer has already expired and the function + // f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop + // does not wait for f to complete before returning. If the caller needs to + // know whether f is completed, it must coordinate with f explicitly. + Stop() bool + + // Reset changes the timer to expire after duration d. + // + // Reset should be invoked only on stopped or expired timers. If the timer is + // known to have expired, Reset can be used directly. Otherwise, the caller + // must coordinate with the function f of Clock.AfterFunc(d, f). + Reset(d time.Duration) } // Address is a byte slice cast as a string that represents the address of a @@ -295,6 +323,29 @@ func (s *Subnet) Broadcast() Address { return Address(addr) } +// IsBroadcast returns true if the address is considered a broadcast address. +func (s *Subnet) IsBroadcast(address Address) bool { + // Only IPv4 supports the notion of a broadcast address. + if len(address) != ipv4AddressSize { + return false + } + + // Normally, we would just compare address with the subnet's broadcast + // address but there is an exception where a simple comparison is not + // correct. This exception is for /31 and /32 IPv4 subnets where all + // addresses are considered valid host addresses. + // + // For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that + // both addresses in a /31 subnet "MUST be interpreted as host addresses." + // + // For /32, the case is a bit more vague. RFC 3021 makes no mention of /32 + // subnets. However, the same reasoning applies - if an exception is not + // made, then there do not exist any host addresses in a /32 subnet. RFC + // 4632 Section 3.1 also vaguely implies this interpretation by referring + // to addresses in /32 subnets as "host routes." + return s.Prefix() <= 30 && s.Broadcast() == address +} + // Equal returns true if s equals o. // // Needed to use cmp.Equal on Subnet as its fields are unexported. @@ -316,6 +367,28 @@ const ( ShutdownWrite ) +// PacketType is used to indicate the destination of the packet. +type PacketType uint8 + +const ( + // PacketHost indicates a packet addressed to the local host. + PacketHost PacketType = iota + + // PacketOtherHost indicates an outgoing packet addressed to + // another host caught by a NIC in promiscuous mode. + PacketOtherHost + + // PacketOutgoing for a packet originating from the local host + // that is looped back to a packet socket. + PacketOutgoing + + // PacketBroadcast indicates a link layer broadcast packet. + PacketBroadcast + + // PacketMulticast indicates a link layer multicast packet. + PacketMulticast +) + // FullAddress represents a full transport node address, as required by the // Connect() and Bind() methods. // @@ -549,6 +622,28 @@ type Endpoint interface { SetOwner(owner PacketOwner) } +// LinkPacketInfo holds Link layer information for a received packet. +// +// +stateify savable +type LinkPacketInfo struct { + // Protocol is the NetworkProtocolNumber for the packet. + Protocol NetworkProtocolNumber + + // PktType is used to indicate the destination of the packet. + PktType PacketType +} + +// PacketEndpoint are additional methods that are only implemented by Packet +// endpoints. +type PacketEndpoint interface { + // ReadPacket reads a datagram/packet from the endpoint and optionally + // returns the sender and additional LinkPacketInfo. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) +} + // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo @@ -585,85 +680,108 @@ type WriteOptions struct { type SockOptBool int const ( - // BroadcastOption is used by SetSockOpt/GetSockOpt to specify whether - // datagram sockets are allowed to send packets to a broadcast address. + // BroadcastOption is used by SetSockOptBool/GetSockOptBool to specify + // whether datagram sockets are allowed to send packets to a broadcast + // address. BroadcastOption SockOptBool = iota - // CorkOption is used by SetSockOpt/GetSockOpt to specify if data should be - // held until segments are full by the TCP transport protocol. + // CorkOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be held until segments are full by the TCP transport + // protocol. CorkOption - // DelayOption is used by SetSockOpt/GetSockOpt to specify if data - // should be sent out immediately by the transport protocol. For TCP, - // it determines if the Nagle algorithm is on or off. + // DelayOption is used by SetSockOptBool/GetSockOptBool to specify if + // data should be sent out immediately by the transport protocol. For + // TCP, it determines if the Nagle algorithm is on or off. DelayOption - // KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether - // TCP keepalive is enabled for this socket. + // KeepaliveEnabledOption is used by SetSockOptBool/GetSockOptBool to + // specify whether TCP keepalive is enabled for this socket. KeepaliveEnabledOption - // MulticastLoopOption is used by SetSockOpt/GetSockOpt to specify whether - // multicast packets sent over a non-loopback interface will be looped back. + // MulticastLoopOption is used by SetSockOptBool/GetSockOptBool to + // specify whether multicast packets sent over a non-loopback interface + // will be looped back. MulticastLoopOption - // PasscredOption is used by SetSockOpt/GetSockOpt to specify whether - // SCM_CREDENTIALS socket control messages are enabled. + // NoChecksumOption is used by SetSockOptBool/GetSockOptBool to specify + // whether UDP checksum is disabled for this socket. + NoChecksumOption + + // PasscredOption is used by SetSockOptBool/GetSockOptBool to specify + // whether SCM_CREDENTIALS socket control messages are enabled. // // Only supported on Unix sockets. PasscredOption - // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. + // QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool. QuickAckOption - // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the - // IPV6_TCLASS ancillary message is passed with incoming packets. + // ReceiveTClassOption is used by SetSockOptBool/GetSockOptBool to + // specify if the IPV6_TCLASS ancillary message is passed with incoming + // packets. ReceiveTClassOption - // ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS - // ancillary message is passed with incoming packets. + // ReceiveTOSOption is used by SetSockOptBool/GetSockOptBool to specify + // if the TOS ancillary message is passed with incoming packets. ReceiveTOSOption - // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify - // if more inforamtion is provided with incoming packets such - // as interface index and address. + // ReceiveIPPacketInfoOption is used by SetSockOptBool/GetSockOptBool to + // specify if more inforamtion is provided with incoming packets such as + // interface index and address. ReceiveIPPacketInfoOption - // ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() - // should allow reuse of local address. + // ReuseAddressOption is used by SetSockOptBool/GetSockOptBool to + // specify whether Bind() should allow reuse of local address. ReuseAddressOption - // ReusePortOption is used by SetSockOpt/GetSockOpt to permit multiple sockets - // to be bound to an identical socket address. + // ReusePortOption is used by SetSockOptBool/GetSockOptBool to permit + // multiple sockets to be bound to an identical socket address. ReusePortOption - // V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6 - // socket is to be restricted to sending and receiving IPv6 packets only. + // V6OnlyOption is used by SetSockOptBool/GetSockOptBool to specify + // whether an IPv6 socket is to be restricted to sending and receiving + // IPv6 packets only. V6OnlyOption + + // IPHdrIncludedOption is used by SetSockOpt to indicate for a raw + // endpoint that all packets being written have an IP header and the + // endpoint should not attach an IP header. + IPHdrIncludedOption ) // SockOptInt represents socket options which values have the int type. type SockOptInt int const ( - // KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number - // of un-ACKed TCP keepalives that will be sent before the connection is - // closed. + // KeepaliveCountOption is used by SetSockOptInt/GetSockOptInt to + // specify the number of un-ACKed TCP keepalives that will be sent + // before the connection is closed. KeepaliveCountOption SockOptInt = iota - // IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS + // IPv4TOSOption is used by SetSockOptInt/GetSockOptInt to specify TOS // for all subsequent outgoing IPv4 packets from the endpoint. IPv4TOSOption - // IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS - // for all subsequent outgoing IPv6 packets from the endpoint. + // IPv6TrafficClassOption is used by SetSockOptInt/GetSockOptInt to + // specify TOS for all subsequent outgoing IPv6 packets from the + // endpoint. IPv6TrafficClassOption - // MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current - // Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. + // MaxSegOption is used by SetSockOptInt/GetSockOptInt to set/get the + // current Maximum Segment Size(MSS) value as specified using the + // TCP_MAXSEG option. MaxSegOption - // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default - // TTL value for multicast messages. The default is 1. + // MTUDiscoverOption is used to set/get the path MTU discovery setting. + // + // NOTE: Setting this option to any other value than PMTUDiscoveryDont + // is not supported and will fail as such, and getting this option will + // always return PMTUDiscoveryDont. + MTUDiscoverOption + + // MulticastTTLOption is used by SetSockOptInt/GetSockOptInt to control + // the default TTL value for multicast messages. The default is 1. MulticastTTLOption // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the @@ -682,26 +800,45 @@ const ( // number of unread bytes in the output buffer should be returned. SendQueueSizeOption - // TTLOption is used by SetSockOpt/GetSockOpt to control the default TTL/hop - // limit value for unicast messages. The default is protocol specific. + // TTLOption is used by SetSockOptInt/GetSockOptInt to control the + // default TTL/hop limit value for unicast messages. The default is + // protocol specific. // // A zero value indicates the default. TTLOption - // TCPSynCountOption is used by SetSockOpt/GetSockOpt to specify the number of - // SYN retransmits that TCP should send before aborting the attempt to - // connect. It cannot exceed 255. + // TCPSynCountOption is used by SetSockOptInt/GetSockOptInt to specify + // the number of SYN retransmits that TCP should send before aborting + // the attempt to connect. It cannot exceed 255. // // NOTE: This option is currently only stubbed out and is no-op. TCPSynCountOption - // TCPWindowClampOption is used by SetSockOpt/GetSockOpt to bound the size - // of the advertised window to this value. + // TCPWindowClampOption is used by SetSockOptInt/GetSockOptInt to bound + // the size of the advertised window to this value. // // NOTE: This option is currently only stubed out and is a no-op TCPWindowClampOption ) +const ( + // PMTUDiscoveryWant is a setting of the MTUDiscoverOption to use + // per-route settings. + PMTUDiscoveryWant int = iota + + // PMTUDiscoveryDont is a setting of the MTUDiscoverOption to disable + // path MTU discovery. + PMTUDiscoveryDont + + // PMTUDiscoveryDo is a setting of the MTUDiscoverOption to always do + // path MTU discovery. + PMTUDiscoveryDo + + // PMTUDiscoveryProbe is a setting of the MTUDiscoverOption to set DF + // but ignore path MTU. + PMTUDiscoveryProbe +) + // ErrorOption is used in GetSockOpt to specify that the last error reported by // the endpoint should be cleared and returned. type ErrorOption struct{} @@ -740,7 +877,7 @@ type CongestionControlOption string // control algorithms. type AvailableCongestionControlOption string -// buffer moderation. +// ModerateReceiveBufferOption is used by buffer moderation. type ModerateReceiveBufferOption bool // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the @@ -813,33 +950,11 @@ type OutOfBandInlineOption int // a default TTL. type DefaultTTLOption uint8 -// StackSACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type StackSACKEnabled bool - -// StackDelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type StackDelayEnabled bool +// SocketDetachFilterOption is used by SetSockOpt to detach a previously attached +// classic BPF filter on a given endpoint. +type SocketDetachFilterOption int -// StackSendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max send buffer sizes. -type StackSendBufferSizeOption struct { - Min int - Default int - Max int -} - -// StackReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// receive buffer sizes. -type StackReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - -// -// IPPacketInfo is the message struture for IP_PKTINFO. +// IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable type IPPacketInfo struct { @@ -1224,6 +1339,12 @@ type UDPStats struct { // PacketSendErrors is the number of datagrams failed to be sent. PacketSendErrors *StatCounter + + // ChecksumErrors is the number of datagrams dropped due to bad checksums. + ChecksumErrors *StatCounter + + // InvalidSourceAddress is the number of invalid sourced datagrams dropped. + InvalidSourceAddress *StatCounter } // Stats holds statistics about the networking stack. @@ -1267,6 +1388,9 @@ type ReceiveErrors struct { // ClosedReceiver is the number of received packets dropped because // of receiving endpoint state being closed. ClosedReceiver StatCounter + + // ChecksumErrors is the number of packets dropped due to bad checksums. + ChecksumErrors StatCounter } // SendErrors collects packet send errors within the transport layer for diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index 7f172f978..f32d58091 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -20,7 +20,7 @@ package tcpip import ( - _ "time" // Used with go:linkname. + "time" // Used with go:linkname. _ "unsafe" // Required for go:linkname. ) @@ -45,3 +45,31 @@ func (*StdClock) NowMonotonic() int64 { _, _, mono := now() return mono } + +// AfterFunc implements Clock.AfterFunc. +func (*StdClock) AfterFunc(d time.Duration, f func()) Timer { + return &stdTimer{ + t: time.AfterFunc(d, f), + } +} + +type stdTimer struct { + t *time.Timer +} + +var _ Timer = (*stdTimer)(nil) + +// Stop implements Timer.Stop. +func (st *stdTimer) Stop() bool { + return st.t.Stop() +} + +// Reset implements Timer.Reset. +func (st *stdTimer) Reset(d time.Duration) { + st.t.Reset(d) +} + +// NewStdTimer returns a Timer implemented with the time package. +func NewStdTimer(t *time.Timer) Timer { + return &stdTimer{t: t} +} diff --git a/pkg/tcpip/timer.go b/pkg/tcpip/timer.go index 59f3b391f..f1dd7c310 100644 --- a/pkg/tcpip/timer.go +++ b/pkg/tcpip/timer.go @@ -15,54 +15,54 @@ package tcpip import ( - "sync" "time" + + "gvisor.dev/gvisor/pkg/sync" ) -// cancellableTimerInstance is a specific instance of CancellableTimer. +// jobInstance is a specific instance of Job. // -// Different instances are created each time CancellableTimer is Reset so each -// timer has its own earlyReturn signal. This is to address a bug when a -// CancellableTimer is stopped and reset in quick succession resulting in a -// timer instance's earlyReturn signal being affected or seen by another timer -// instance. +// Different instances are created each time Job is scheduled so each timer has +// its own earlyReturn signal. This is to address a bug when a Job is stopped +// and reset in quick succession resulting in a timer instance's earlyReturn +// signal being affected or seen by another timer instance. // // Consider the following sceneario where timer instances share a common // earlyReturn signal (T1 creates, stops and resets a Cancellable timer under a // lock L; T2, T3, T4 and T5 are goroutines that handle the first (A), second // (B), third (C), and fourth (D) instance of the timer firing, respectively): // T1: Obtain L -// T1: Create a new CancellableTimer w/ lock L (create instance A) +// T1: Create a new Job w/ lock L (create instance A) // T2: instance A fires, blocked trying to obtain L. // T1: Attempt to stop instance A (set earlyReturn = true) -// T1: Reset timer (create instance B) +// T1: Schedule timer (create instance B) // T3: instance B fires, blocked trying to obtain L. // T1: Attempt to stop instance B (set earlyReturn = true) -// T1: Reset timer (create instance C) +// T1: Schedule timer (create instance C) // T4: instance C fires, blocked trying to obtain L. // T1: Attempt to stop instance C (set earlyReturn = true) -// T1: Reset timer (create instance D) +// T1: Schedule timer (create instance D) // T5: instance D fires, blocked trying to obtain L. // T1: Release L // -// Now that T1 has released L, any of the 4 timer instances can take L and check -// earlyReturn. If the timers simply check earlyReturn and then do nothing -// further, then instance D will never early return even though it was not -// requested to stop. If the timers reset earlyReturn before early returning, -// then all but one of the timers will do work when only one was expected to. -// If CancellableTimer resets earlyReturn when resetting, then all the timers +// Now that T1 has released L, any of the 4 timer instances can take L and +// check earlyReturn. If the timers simply check earlyReturn and then do +// nothing further, then instance D will never early return even though it was +// not requested to stop. If the timers reset earlyReturn before early +// returning, then all but one of the timers will do work when only one was +// expected to. If Job resets earlyReturn when resetting, then all the timers // will fire (again, when only one was expected to). // // To address the above concerns the simplest solution was to give each timer // its own earlyReturn signal. -type cancellableTimerInstance struct { - timer *time.Timer +type jobInstance struct { + timer Timer // Used to inform the timer to early return when it gets stopped while the // lock the timer tries to obtain when fired is held (T1 is a goroutine that // tries to cancel the timer and T2 is the goroutine that handles the timer // firing): - // T1: Obtain the lock, then call StopLocked() + // T1: Obtain the lock, then call Cancel() // T2: timer fires, and gets blocked on obtaining the lock // T1: Releases lock // T2: Obtains lock does unintended work @@ -73,27 +73,33 @@ type cancellableTimerInstance struct { earlyReturn *bool } -// stop stops the timer instance t from firing if it hasn't fired already. If it +// stop stops the job instance j from firing if it hasn't fired already. If it // has fired and is blocked at obtaining the lock, earlyReturn will be set to // true so that it will early return when it obtains the lock. -func (t *cancellableTimerInstance) stop() { - if t.timer != nil { - t.timer.Stop() - *t.earlyReturn = true +func (j *jobInstance) stop() { + if j.timer != nil { + j.timer.Stop() + *j.earlyReturn = true } } -// CancellableTimer is a timer that does some work and can be safely cancelled -// when it fires at the same time some "related work" is being done. +// Job represents some work that can be scheduled for execution. The work can +// be safely cancelled when it fires at the same time some "related work" is +// being done. // // The term "related work" is defined as some work that needs to be done while // holding some lock that the timer must also hold while doing some work. // -// Note, it is not safe to copy a CancellableTimer as its timer instance creates -// a closure over the address of the CancellableTimer. -type CancellableTimer struct { +// Note, it is not safe to copy a Job as its timer instance creates +// a closure over the address of the Job. +type Job struct { + _ sync.NoCopy + + // The clock used to schedule the backing timer + clock Clock + // The active instance of a cancellable timer. - instance cancellableTimerInstance + instance jobInstance // locker is the lock taken by the timer immediately after it fires and must // be held when attempting to stop the timer. @@ -110,75 +116,91 @@ type CancellableTimer struct { fn func() } -// StopLocked prevents the Timer from firing if it has not fired already. +// Cancel prevents the Job from executing if it has not executed already. // -// If the timer is blocked on obtaining the t.locker lock when StopLocked is -// called, it will early return instead of calling t.fn. +// Cancel requires appropriate locking to be in place for any resources managed +// by the Job. If the Job is blocked on obtaining the lock when Cancel is +// called, it will early return. // // Note, t will be modified. // -// t.locker MUST be locked. -func (t *CancellableTimer) StopLocked() { - t.instance.stop() +// j.locker MUST be locked. +func (j *Job) Cancel() { + j.instance.stop() // Nothing to do with the stopped instance anymore. - t.instance = cancellableTimerInstance{} + j.instance = jobInstance{} } -// Reset changes the timer to expire after duration d. +// Schedule schedules the Job for execution after duration d. This can be +// called on cancelled or completed Jobs to schedule them again. // -// Note, t will be modified. +// Schedule should be invoked only on unscheduled, cancelled, or completed +// Jobs. To be safe, callers should always call Cancel before calling Schedule. // -// Reset should only be called on stopped or expired timers. To be safe, callers -// should always call StopLocked before calling Reset. -func (t *CancellableTimer) Reset(d time.Duration) { +// Note, j will be modified. +func (j *Job) Schedule(d time.Duration) { // Create a new instance. earlyReturn := false // Capture the locker so that updating the timer does not cause a data race // when a timer fires and tries to obtain the lock (read the timer's locker). - locker := t.locker - t.instance = cancellableTimerInstance{ - timer: time.AfterFunc(d, func() { + locker := j.locker + j.instance = jobInstance{ + timer: j.clock.AfterFunc(d, func() { locker.Lock() defer locker.Unlock() if earlyReturn { // If we reach this point, it means that the timer fired while another - // goroutine called StopLocked while it had the lock. Simply return - // here and do nothing further. + // goroutine called Cancel while it had the lock. Simply return here + // and do nothing further. earlyReturn = false return } - t.fn() + j.fn() }), earlyReturn: &earlyReturn, } } -// Lock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Lock() {} - -// Unlock is a no-op used by the copylocks checker from go vet. -// -// See CancellableTimer for details about why it shouldn't be copied. -// -// See https://github.com/golang/go/issues/8005#issuecomment-190753527 for more -// details about the copylocks checker. -func (*CancellableTimer) Unlock() {} - -// NewCancellableTimer returns an unscheduled CancellableTimer with the given -// locker and fn. -// -// fn MUST NOT attempt to lock locker. -// -// Callers must call Reset to schedule the timer to fire. -func NewCancellableTimer(locker sync.Locker, fn func()) *CancellableTimer { - return &CancellableTimer{locker: locker, fn: fn} +// NewJob returns a new Job that can be used to schedule f to run in its own +// gorountine. l will be locked before calling f then unlocked after f returns. +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// message = "bar" +// mu.Unlock() +// +// // Output: bar +// +// f MUST NOT attempt to lock l. +// +// l MUST be locked prior to calling the returned job's Cancel(). +// +// var clock tcpip.StdClock +// var mu sync.Mutex +// message := "foo" +// job := tcpip.NewJob(&clock, &mu, func() { +// fmt.Println(message) +// }) +// job.Schedule(time.Second) +// +// mu.Lock() +// job.Cancel() +// mu.Unlock() +func NewJob(c Clock, l sync.Locker, f func()) *Job { + return &Job{ + clock: c, + locker: l, + fn: f, + } } diff --git a/pkg/tcpip/timer_test.go b/pkg/tcpip/timer_test.go index b4940e397..a82384c49 100644 --- a/pkg/tcpip/timer_test.go +++ b/pkg/tcpip/timer_test.go @@ -28,8 +28,8 @@ const ( longDuration = 1 * time.Second ) -func TestCancellableTimerReassignment(t *testing.T) { - var timer tcpip.CancellableTimer +func TestJobReschedule(t *testing.T) { + var clock tcpip.StdClock var wg sync.WaitGroup var lock sync.Mutex @@ -43,26 +43,27 @@ func TestCancellableTimerReassignment(t *testing.T) { // that has an active timer (even if it has been stopped as a stopped // timer may be blocked on a lock before it can check if it has been // stopped while another goroutine holds the same lock). - timer = *tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { wg.Done() }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) lock.Unlock() }() } wg.Wait() } -func TestCancellableTimerFire(t *testing.T) { +func TestJobExecution(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -82,17 +83,18 @@ func TestCancellableTimerFire(t *testing.T) { func TestCancellableTimerResetFromLongDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(middleDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(middleDuration) lock.Lock() - timer.StopLocked() + job.Cancel() lock.Unlock() - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -109,16 +111,17 @@ func TestCancellableTimerResetFromLongDuration(t *testing.T) { } } -func TestCancellableTimerResetFromShortDuration(t *testing.T) { +func TestJobRescheduleFromShortDuration(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() // Wait for timer to fire if it wasn't correctly stopped. @@ -128,7 +131,7 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { case <-time.After(middleDuration): } - timer.Reset(shortDuration) + job.Schedule(shortDuration) // Wait for timer to fire. select { @@ -145,17 +148,18 @@ func TestCancellableTimerResetFromShortDuration(t *testing.T) { } } -func TestCancellableTimerImmediatelyStop(t *testing.T) { +func TestJobImmediatelyCancel(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) for i := 0; i < 1000; i++ { lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() } @@ -167,25 +171,26 @@ func TestCancellableTimerImmediatelyStop(t *testing.T) { } } -func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { +func TestJobCancelledRescheduleWithoutLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) - timer.StopLocked() + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) + job.Cancel() lock.Unlock() for i := 0; i < 10; i++ { - timer.Reset(middleDuration) + job.Schedule(middleDuration) lock.Lock() // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration * 2) - timer.StopLocked() + job.Cancel() lock.Unlock() } @@ -201,17 +206,18 @@ func TestCancellableTimerStoppedResetWithoutLock(t *testing.T) { func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { // Sleep until the timer fires and gets blocked trying to take the lock. time.Sleep(middleDuration) - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() @@ -230,18 +236,19 @@ func TestManyCancellableTimerResetAfterBlockedOnLock(t *testing.T) { } } -func TestManyCancellableTimerResetUnderLock(t *testing.T) { +func TestManyJobReschedulesUnderLock(t *testing.T) { t.Parallel() - ch := make(chan struct{}) + var clock tcpip.StdClock var lock sync.Mutex + ch := make(chan struct{}) lock.Lock() - timer := tcpip.NewCancellableTimer(&lock, func() { ch <- struct{}{} }) - timer.Reset(shortDuration) + job := tcpip.NewJob(&clock, &lock, func() { ch <- struct{}{} }) + job.Schedule(shortDuration) for i := 0; i < 10; i++ { - timer.StopLocked() - timer.Reset(shortDuration) + job.Cancel() + job.Schedule(shortDuration) } lock.Unlock() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..4612be4e7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } @@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index baf08eda6..df478115d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,6 +25,8 @@ package packet import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -43,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -71,11 +76,17 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } @@ -132,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -158,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } @@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } @@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } @@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // HandlePacket implements stack.PacketEndpoint.HandlePacket. @@ -315,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. + // TODO(gvisor.dev/issue/173): Return network protocol. if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader) @@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. + combinedVV.AppendView(pkt.Header.View()) } - combinedVV := linkHeader.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6a7977259..f85a68554 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -94,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if netProto != header.IPv4ProtocolNumber { + if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -108,16 +109,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(transProto, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(transProto, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -215,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -258,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -319,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrNoRoute } - // We don't support IPv6 yet, so this has to be an IPv4 address. - if len(opts.To.Addr) != header.IPv4AddressSize { - e.mu.RUnlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) @@ -354,17 +351,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch e.NetProto { - case header.IPv4ProtocolNumber: - if !e.associated { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - break + if e.hdrIncluded { + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { + return 0, nil, err } - + } else { hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, @@ -373,9 +366,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, }); err != nil { return 0, nil, err } - - default: - return 0, nil, tcpip.ErrUnknownProtocol } return int64(len(payloadBytes)), nil, nil @@ -400,11 +390,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // We don't support IPv6 yet. - if len(addr.Addr) != header.IPv4AddressSize { - return tcpip.ErrInvalidEndpointState - } - nic := addr.NIC if e.bound { if e.BindNICID == 0 { @@ -470,14 +455,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // Callers must provide an IPv4 address or no network address (for - // binding to a NIC, but not an address). - if len(addr.Addr) != 0 && len(addr.Addr) != 4 { - return tcpip.ErrInvalidEndpointState - } - // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -527,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -541,9 +533,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) } if v > ss.Max { v = ss.Max @@ -559,9 +551,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) } if v > rs.Max { v = rs.Max @@ -596,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -635,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -680,12 +685,22 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - headers := append(buffer.View(nil), pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) - combinedVV := headers.ToVectorisedView() + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) + headers = append(headers, pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 6baeda8e4..e860ee484 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -49,6 +49,7 @@ go_library( "segment_heap.go", "segment_queue.go", "segment_state.go", + "segment_unsafe.go", "snd.go", "snd_state.go", "tcp_endpoint_list.go", @@ -86,6 +87,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7679fe169..6e00e5526 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -199,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. The endpoint is returned -// with n.mu held. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -227,22 +225,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - // Lock the endpoint before registering to ensure that no out of - // band changes are possible due to incoming packets etc till - // the endpoint is done initializing. - n.mu.Lock() - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { - n.mu.Unlock() - n.Close() - return nil, err - } - - n.isRegistered = true - n.registeredReusePort = n.reusePort - - return n, nil + return n } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -253,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) - if err != nil { - return nil, err - } + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. @@ -264,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) @@ -284,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // to the newly created endpoint. l.listenEP.propagateInheritableOptionsLocked(ep) + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + ep.notifyAborted() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -374,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // Precondition: e.mu and n.mu must be held. func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -568,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + + if !n.reserveTupleLocked() { + n.mu.Unlock() + n.Close() + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } - // Propagate any inheritable options from the listening endpoint - // to the newly created endpoint. - e.propagateInheritableOptionsLocked(n) + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return + } + + n.isRegistered = true // clear the tsOffset for the newly created // endpoint as the Timestamp was already diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 377643b82..1798510bc 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } } // Wait for notification. @@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error { // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, func() { - resendWaker.Assert() - }) + rt := time.AfterFunc(timeOut, resendWaker.Assert) defer rt.Stop() // Set up the wakers. @@ -521,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled tcpip.StackSACKEnabled + var sackEnabled SACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -1050,8 +1054,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { panic("current endpoint not removed from demuxer, enqueing segments to itself") } - if ep.(*endpoint).enqueueSegment(s) { - ep.(*endpoint).newSegmentWaker.Assert() + if ep := ep.(*endpoint); ep.enqueueSegment(s) { + ep.newSegmentWaker.Assert() } } @@ -1120,7 +1124,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState() == StateClose || e.EndpointState() == StateError { + if e.EndpointState().closed() { return nil } s := e.segmentQueue.dequeue() @@ -1440,9 +1444,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { - closeWaker.Assert() - }) + closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } } @@ -1460,7 +1462,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return err } } - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + if !e.EndpointState().closed() { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1526,7 +1528,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } loop: - for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { + for { + switch e.EndpointState() { + case StateTimeWait, StateClose, StateError: + break loop + } + e.mu.Unlock() v, _ := s.Fetch(true) e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 047704c80..98aecab9e 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -15,6 +15,8 @@ package tcp import ( + "encoding/binary" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -66,89 +68,68 @@ func (q *epQueue) empty() bool { // processor is responsible for processing packets queued to a tcp endpoint. type processor struct { epQ epQueue + sleeper sleep.Sleeper newEndpointWaker sleep.Waker closeWaker sleep.Waker - id int - wg sync.WaitGroup -} - -func newProcessor(id int) *processor { - p := &processor{ - id: id, - } - p.wg.Add(1) - go p.handleSegments() - return p } func (p *processor) close() { p.closeWaker.Assert() } -func (p *processor) wait() { - p.wg.Wait() -} - func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) p.newEndpointWaker.Assert() } -func (p *processor) handleSegments() { - const newEndpointWaker = 1 - const closeWaker = 2 - s := sleep.Sleeper{} - s.AddWaker(&p.newEndpointWaker, newEndpointWaker) - s.AddWaker(&p.closeWaker, closeWaker) - defer s.Done() +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + for { - id, ok := s.Fetch(true) - if ok && id == closeWaker { - p.wg.Done() - return + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break } - for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } if ep.segmentQueue.empty() { continue } - // If socket has transitioned out of connected state - // then just let the worker handle the packet. + // If socket has transitioned out of connected state then just let the + // worker handle the packet. // - // NOTE: We read this outside of e.mu lock which means - // that by the time we get to handleSegments the - // endpoint may not be in ESTABLISHED. But this should - // be fine as all normal shutdown states are handled by - // handleSegments and if the endpoint moves to a - // CLOSED/ERROR state then handleSegments is a noop. - if ep.EndpointState() != StateEstablished { - ep.newSegmentWaker.Assert() - continue - } - - if !ep.mu.TryLock() { - ep.newSegmentWaker.Assert() - continue - } - // If the endpoint is in a connected state then we do - // direct delivery to ensure low latency and avoid - // scheduler interactions. - if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { - // Send any active resets if required. - if err != nil { + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. ep.resetConnectionLocked(err) + fallthrough + case ep.EndpointState() == StateClose: + ep.notifyProtocolGoroutine(notifyTickleWorker) + case !ep.segmentQueue.empty(): + p.epQ.enqueue(ep) } - ep.notifyProtocolGoroutine(notifyTickleWorker) ep.mu.Unlock() - continue - } - - if !ep.segmentQueue.empty() { - p.epQ.enqueue(ep) + } else { + ep.newSegmentWaker.Assert() } - - ep.mu.Unlock() } } } @@ -159,31 +140,36 @@ func (p *processor) handleSegments() { // hash of the endpoint id to ensure that delivery for the same endpoint happens // in-order. type dispatcher struct { - processors []*processor + processors []processor seed uint32 -} - -func newDispatcher(nProcessors int) *dispatcher { - processors := []*processor{} - for i := 0; i < nProcessors; i++ { - processors = append(processors, newProcessor(i)) - } - return &dispatcher{ - processors: processors, - seed: generateRandUint32(), + wg sync.WaitGroup +} + +func (d *dispatcher) init(nProcessors int) { + d.close() + d.wait() + d.processors = make([]processor, nProcessors) + d.seed = generateRandUint32() + for i := range d.processors { + p := &d.processors[i] + p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker, closeWaker) + d.wg.Add(1) + // NB: sleeper-waker registration must happen synchronously to avoid races + // with `close`. It's possible to pull all this logic into `start`, but + // that results in a heap-allocated function literal. + go p.start(&d.wg) } } func (d *dispatcher) close() { - for _, p := range d.processors { - p.close() + for i := range d.processors { + d.processors[i].close() } } func (d *dispatcher) wait() { - for _, p := range d.processors { - p.wait() - } + d.wg.Wait() } func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -231,20 +217,18 @@ func generateRandUint32() uint32 { if _, err := rand.Read(b); err != nil { panic(err) } - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return binary.LittleEndian.Uint32(b) } func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { - payload := []byte{ - byte(id.LocalPort), - byte(id.LocalPort >> 8), - byte(id.RemotePort), - byte(id.RemotePort >> 8)} + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) + binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) h := jenkins.Sum32(d.seed) - h.Write(payload) + h.Write(payload[:]) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) - return d.processors[h.Sum32()%uint32(len(d.processors))] + return &d.processors[h.Sum32()%uint32(len(d.processors))] } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 10df2bcd5..0f7487963 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -396,7 +396,8 @@ type endpoint struct { mu sync.Mutex `state:"nosave"` ownedByUser uint32 - // state must be read/set using the EndpointState()/setEndpointState() methods. + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -405,8 +406,8 @@ type endpoint struct { origEndpointState EndpointState `state:"nosave"` isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -415,10 +416,14 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags + boundDest tcpip.FullAddress // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -426,7 +431,7 @@ type endpoint struct { // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -462,13 +467,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - - // registeredReusePort is set if the current endpoint registration was - // done with SO_REUSEPORT enabled. - registeredReusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -488,7 +486,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -838,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -851,12 +847,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -871,7 +867,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } - var de tcpip.StackDelayEnabled + var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1025,15 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} } // Mark endpoint as closed. @@ -1091,17 +1087,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1213,9 +1209,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() + defer e.UnlockUser() + + // When in SYN-SENT state, let the caller block on the receive. + // An application can initiate a non-blocking connect and then block + // on a receive. It can expect to read any data after the handshake + // is complete. RFC793, section 3.9, p58. + if e.EndpointState() == StateSynSent { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1225,7 +1239,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1235,7 +1248,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1522,12 +1534,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { case tcpip.ReuseAddressOption: e.LockUser() - e.reuseAddr = v + e.portFlags.TupleOnly = v e.UnlockUser() case tcpip.ReusePortOption: e.LockUser() - e.reusePort = v + e.portFlags.LoadBalanced = v e.UnlockUser() case tcpip.V6OnlyOption: @@ -1585,10 +1597,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1638,7 +1657,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1678,7 +1697,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1781,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } @@ -1831,14 +1853,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.ReuseAddressOption: e.LockUser() - v := e.reuseAddr + v := e.portFlags.TupleOnly e.UnlockUser() return v, nil case tcpip.ReusePortOption: e.LockUser() - v := e.reusePort + v := e.portFlags.LoadBalanced e.UnlockUser() return v, nil @@ -1892,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1937,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err + return e.takeLastError() case *tcpip.BindToDeviceOption: e.LockUser() @@ -2091,8 +2114,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -2100,11 +2121,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } - e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2128,40 +2148,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { - case nil: - // Port picking successful. Save the details of - // the selected port. - e.ID = id - e.boundBindToDevice = e.bindToDevice - e.registeredReusePort = e.reusePort - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) - e.isPortReserved = false - } - e.isRegistered = true e.setEndpointState(StateConnecting) e.route = r.Clone() @@ -2340,13 +2353,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) @@ -2433,16 +2445,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - flags := ports.Flags{ - LoadBalanced: e.reusePort, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = flags + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port @@ -2450,7 +2459,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Any failures beyond this point must remove the port registration. defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 @@ -2473,6 +2482,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2537,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } @@ -2609,7 +2634,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2690,7 +2715,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v tcpip.StackSACKEnabled + var v SACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 0bebec2d1..abf1ac5c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() { if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - - if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { - panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") - } } // saveAcceptedChan is invoked by stateify. @@ -186,26 +182,33 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - if len(e.BindAddr) == 0 { - e.BindAddr = e.ID.LocalAddress + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + if err != nil { + panic("unable to parse BindAddr: " + err.String()) } - addr := e.BindAddr - port := e.ID.LocalPort - if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { - panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) + if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } + e.isPortReserved = true + + // Mark endpoint as bound. + e.setEndpointState(StateBound) } switch { @@ -277,17 +280,7 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() case epState == StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.setEndpointState(StateClose) - tcpip.AsyncLoading.Done() - }() - } + e.isPortReserved = false e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 3cff55afa..b34e47bbd 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -76,6 +76,31 @@ const ( ccCubic = "cubic" ) +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type DelayEnabled bool + +// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -137,8 +162,8 @@ type protocol struct { mu sync.RWMutex sackEnabled bool delayEnabled bool - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -149,7 +174,7 @@ type protocol struct { maxRetries uint32 synRcvdCount synRcvdCounter synRetries uint8 - dispatcher *dispatcher + dispatcher dispatcher } // Number returns the tcp protocol number. @@ -249,19 +274,19 @@ func replyWithReset(s *segment, tos, ttl uint8) { // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case tcpip.StackSACKEnabled: + case SACKEnabled: p.mu.Lock() p.sackEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackDelayEnabled: + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackSendBufferSizeOption: + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -270,7 +295,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case tcpip.StackReceiveBufferSizeOption: + case ReceiveBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -363,25 +388,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *tcpip.StackSACKEnabled: + case *SACKEnabled: p.mu.RLock() - *v = tcpip.StackSACKEnabled(p.sackEnabled) + *v = SACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *tcpip.StackDelayEnabled: + case *DelayEnabled: p.mu.RLock() - *v = tcpip.StackDelayEnabled(p.delayEnabled) + *v = DelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *tcpip.StackSendBufferSizeOption: + case *SendBufferSizeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *tcpip.StackReceiveBufferSizeOption: + case *ReceiveBufferSizeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -490,13 +515,13 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{ + p := protocol{ + sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{ + recvBufferSize: ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -506,10 +531,11 @@ func NewProtocol() stack.TransportProtocol { tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, - dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index dd89a292a..5e0bfe585 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // We only store the segment if it's within our buffer // size limit. if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += s.logicalLen() + r.pendingBufUsed += seqnum.Size(s.segMemSize()) s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= s.logicalLen() + r.pendingBufUsed -= seqnum.Size(s.segMemSize()) s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0280892a8..bb60dc29d 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// segMemSize is the amount of memory used to hold the segment data and +// the associated metadata. +func (s *segment) segMemSize() int { + return segSize + s.data.Size() +} + // parse populates the sequence & ack numbers, flags, and window fields of the // segment from the TCP header stored in the data. It then updates the view to // skip the header. diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go new file mode 100644 index 000000000..0ab7b8f56 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "unsafe" +) + +const ( + segSize = int(unsafe.Sizeof(segment{})) +) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index acacb42e4..5862c32f2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -833,25 +833,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } - segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - // If the entire segment cannot be accomodated in the receiver - // advertized window, skip splitting and sending of the segment. - // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() - // - // Linux checks this for all segment transmits not triggered - // by a probe timer. On this condition, it defers the segment - // split and transmit to a short probe timer. - // ref: include/net/tcp.h::tcp_check_probe_timer() - // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() - // - // Instead of defining a new transmit timer, we attempt to split the - // segment right here if there are no pending segments. - // If there are pending segments, segment transmits are deferred - // to the retransmit timer handler. - if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { - return false - } - if !seg.sequenceNumber.LessThan(end) { return false } @@ -861,14 +842,48 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return false } - // The segment size limit is computed as a function of sender congestion - // window and MSS. When sender congestion window is > 1, this limit can - // be larger than MSS. Ensure that the currently available send space - // is not greater than minimum of this limit and MSS. + // If the whole segment or at least 1MSS sized segment cannot + // be accomodated in the receiver advertized window, skip + // splitting and sending of the segment. ref: + // net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered by + // a probe timer. On this condition, it defers the segment split + // and transmit to a short probe timer. + // + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split + // the segment right here if there are no pending segments. If + // there are pending segments, segment transmits are deferred to + // the retransmit timer handler. + if s.sndUna != s.sndNxt { + switch { + case available >= seg.data.Size(): + // OK to send, the whole segments fits in the + // receiver's advertised window. + case available >= s.maxPayloadSize: + // OK to send, at least 1 MSS sized segment fits + // in the receiver's advertised window. + default: + return false + } + } + + // The segment size limit is computed as a function of sender + // congestion window and MSS. When sender congestion window is > + // 1, this limit can be larger than MSS. Ensure that the + // currently available send space is not greater than minimum of + // this limit and MSS. if available > limit { available = limit } - if available > s.maxPayloadSize { + + // If GSO is not in use then cap available to + // maxPayloadSize. When GSO is in use the gVisor GSO logic or + // the host GSO logic will cap the segment to the correct size. + if s.ep.gso == nil && available > s.maxPayloadSize { available = s.maxPayloadSize } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 812e503bc..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index aca6a7951..e67ec42b1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3095,6 +3095,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3879,6 +3936,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3888,6 +3948,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3898,6 +3961,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3910,6 +3976,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3920,6 +3989,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3932,6 +4004,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3987,7 +4062,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ Min: 1, Default: tcp.DefaultSendBufferSize * 2, Max: tcp.DefaultSendBufferSize * 20}); err != nil { @@ -4004,7 +4079,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ Min: 1, Default: tcp.DefaultReceiveBufferSize * 3, Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { @@ -4035,11 +4110,11 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5678,7 +5753,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5799,7 +5874,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5941,7 +6016,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcpip.StackDelayEnabled + delayEnabled tcp.DelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5956,10 +6031,10 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcpip.StackDelayEnabled + var gotDelayEnabled tcp.DelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9e262c272..37e7767d6 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context { TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, }) + const sendBufferSize = 1 << 20 // 1 MiB + const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context { t: t, s: s, linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), + WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), } } @@ -1091,7 +1093,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcpip.StackSACKEnabled + var v tcp.SACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f51988047..b7d735889 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,6 +109,7 @@ type endpoint struct { portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool + noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -190,13 +191,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -231,7 +232,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -482,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - if to.Addr == header.IPv4Broadcast && !e.broadcast { - return 0, nil, tcpip.ErrBroadcastDisabled - } - dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -502,6 +499,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } + if !e.broadcast && route.IsBroadcast() { + return 0, nil, tcpip.ErrBroadcastDisabled + } + if route.IsResolutionRequired() { if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { @@ -529,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -553,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -606,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -629,9 +642,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) } if v < rs.Min { @@ -648,9 +661,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) } if v < ss.Min { @@ -803,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -825,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -894,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -959,7 +985,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -973,8 +999,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Length: length, }) - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -1047,7 +1077,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } e.state = StateInitial @@ -1198,7 +1228,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return id, e.bindToDevice, err } @@ -1208,7 +1238,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err @@ -1350,10 +1380,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - e.rcvMu.Lock() + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + + // Verify checksum unless RX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value means + // the transmitter omitted the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + if hdr.CalculateChecksum(xsum) != 0xffff { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return + } + } + e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1394,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index fc93f93c0..0e7464e3a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -21,7 +21,6 @@ package udp import ( - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,9 +49,6 @@ const ( ) type protocol struct { - mu sync.RWMutex - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption } // Number returns the udp protocol number. @@ -203,48 +199,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case tcpip.StackSendBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.sendBufferSize = v - p.mu.Unlock() - return nil - - case tcpip.StackReceiveBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.recvBufferSize = v - p.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *tcpip.StackSendBufferSizeOption: - p.mu.RLock() - *v = p.sendBufferSize - p.mu.RUnlock() - return nil - - case *tcpip.StackReceiveBufferSizeOption: - p.mu.RLock() - *v = p.recvBufferSize - p.mu.RUnlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Close implements stack.TransportProtocol.Close. @@ -267,8 +227,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize}, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize}, - } + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 313a3f117..66e8911c8 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio wep = sniffer.New(ep) } if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } s.SetRouteTable([]tcpip.Route{ @@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } else { - c.injectV6Packet(payload, &h, true /* valid */) + buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } } -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: l, + Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. @@ -439,17 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -483,11 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } func newPayload() []byte { @@ -509,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() @@ -643,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } } @@ -654,19 +661,19 @@ func TestBindReservedPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } addr, err := c.ep.GetLocalAddress() if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) + t.Fatalf("GetLocalAddress failed: %s", err) } // We can't bind the address reserved by the connected endpoint above. { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { @@ -677,7 +684,7 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint @@ -687,7 +694,7 @@ func TestBindReservedPort(t *testing.T) { } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() @@ -697,11 +704,11 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() } @@ -714,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -729,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { // Bind to v4 mapped wildcard. if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -744,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -759,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -796,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) { h := unicastV4.header4Tuple(incoming) h.srcAddr = h.dstAddr - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -817,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -880,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -955,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... payload := buffer.View(newPayload()) n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { - c.t.Fatalf("Write failed: %v", err) + c.t.Fatalf("Write failed: %s", err) } if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) @@ -1005,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } p := testDualWrite(c) @@ -1022,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV6) @@ -1043,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV4in6) @@ -1070,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Write to v6 address. @@ -1085,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV6) @@ -1099,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV4) @@ -1234,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testRead(c, unicastV4) @@ -1259,6 +1323,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1506,12 +1594,12 @@ func TestMulticastInterfaceOption(t *testing.T) { Port: stackPort, } if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } } if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOpt failed: %s", err) } // Verify multicast interface addr and NIC were set correctly. @@ -1519,7 +1607,7 @@ func TestMulticastInterfaceOption(t *testing.T) { ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) + c.t.Fatalf("GetSockOpt failed: %s", err) } if ifoptGot != ifoptWant { c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) @@ -1691,7 +1779,7 @@ func TestV6UnknownDestination(t *testing.T) { } // TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. +// global and endpoint stats are incremented. func TestIncrementMalformedPacketsReceived(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1699,20 +1787,27 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } payload := newPayload() - c.t.Helper() h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) + buf := c.buildV6Packet(payload, &h) - var want uint64 = 1 + // Invalidate the UDP header length field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetLength(u.Length() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) } if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -1728,7 +1823,6 @@ func TestShortHeader(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - c.t.Helper() h := unicastV6.header4Tuple(incoming) // Allocate a buffer for an IPv6 and too-short UDP header. @@ -1768,6 +1862,199 @@ func TestShortHeader(t *testing.T) { } } +// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + + // Invalidate the UDP header checksum field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV4 verifies if the checksum value is zero, global and +// endpoint states are *not* incremented (UDP checksum is optional on IPv4). +func TestChecksumZeroV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 0 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV6 verifies if the checksum value is zero, global and +// endpoint states are incremented (UDP checksum is *not* optional on IPv6). +func TestChecksumZeroV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { @@ -1778,15 +2065,15 @@ func TestShutdownRead(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingRead(c, unicastV6, true /* expectReadError */) @@ -1809,11 +2096,11 @@ func TestShutdownWrite(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) @@ -1855,3 +2142,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const nicID1 = 1 + + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + requiresBroadcastOpt bool + }{ + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + requiresBroadcastOpt: true, + }, + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + requiresBroadcastOpt: false, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + requiresBroadcastOpt: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) + } + defer ep.Close() + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + to := tcpip.FullAddress{ + Addr: test.remoteAddr, + Port: 80, + } + opts := tcpip.WriteOptions{To: &to} + expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + if !test.requiresBroadcastOpt { + expectedErrWithoutBcastOpt = nil + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + }) + } +} diff --git a/pkg/test/criutil/criutil.go b/pkg/test/criutil/criutil.go index 8fed29ff5..70945f234 100644 --- a/pkg/test/criutil/criutil.go +++ b/pkg/test/criutil/criutil.go @@ -22,6 +22,9 @@ import ( "fmt" "os" "os/exec" + "path" + "regexp" + "strconv" "strings" "time" @@ -33,28 +36,44 @@ import ( type Crictl struct { logger testutil.Logger endpoint string + runpArgs []string cleanup []func() } -// resolvePath attempts to find binary paths. It may set the path to invalid, +// ResolvePath attempts to find binary paths. It may set the path to invalid, // which will cause the execution to fail with a sensible error. -func resolvePath(executable string) string { +func ResolvePath(executable string) string { + runtime, err := dockerutil.RuntimePath() + if err == nil { + // Check first the directory of the runtime itself. + if dir := path.Dir(runtime); dir != "" && dir != "." { + guess := path.Join(dir, executable) + if fi, err := os.Stat(guess); err == nil && (fi.Mode()&0111) != 0 { + return guess + } + } + } + + // Try to find via the path. guess, err := exec.LookPath(executable) - if err != nil { - guess = fmt.Sprintf("/usr/local/bin/%s", executable) + if err == nil { + return guess } - return guess + + // Return a default path. + return fmt.Sprintf("/usr/local/bin/%s", executable) } // NewCrictl returns a Crictl configured with a timeout and an endpoint over // which it will talk to containerd. -func NewCrictl(logger testutil.Logger, endpoint string) *Crictl { +func NewCrictl(logger testutil.Logger, endpoint string, runpArgs []string) *Crictl { // Attempt to find the executable, but don't bother propagating the // error at this point. The first command executed will return with a // binary not found error. return &Crictl{ logger: logger, endpoint: endpoint, + runpArgs: runpArgs, } } @@ -67,8 +86,8 @@ func (cc *Crictl) CleanUp() { } // RunPod creates a sandbox. It corresponds to `crictl runp`. -func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { - podID, err := cc.run("runp", sbSpecFile) +func (cc *Crictl) RunPod(runtime, sbSpecFile string) (string, error) { + podID, err := cc.run("runp", "--runtime", runtime, sbSpecFile) if err != nil { return "", fmt.Errorf("runp failed: %v", err) } @@ -79,10 +98,42 @@ func (cc *Crictl) RunPod(sbSpecFile string) (string, error) { // Create creates a container within a sandbox. It corresponds to `crictl // create`. func (cc *Crictl) Create(podID, contSpecFile, sbSpecFile string) (string, error) { - podID, err := cc.run("create", podID, contSpecFile, sbSpecFile) + // In version 1.16.0, crictl annoying starting attempting to pull the + // container, even if it was already available locally. We therefore + // need to parse the version and add an appropriate --no-pull argument + // since the image has already been loaded locally. + out, err := cc.run("-v") + if err != nil { + return "", err + } + r := regexp.MustCompile("crictl version ([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(out) + if len(vs) != 4 { + return "", fmt.Errorf("crictl -v had unexpected output: %s", out) + } + major, err := strconv.ParseUint(vs[1], 10, 64) if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + return "", fmt.Errorf("crictl had invalid version: %v (%s)", err, out) + } + + args := []string{"create"} + if (major == 1 && minor >= 16) || major > 1 { + args = append(args, "--no-pull") + } + args = append(args, podID) + args = append(args, contSpecFile) + args = append(args, sbSpecFile) + + podID, err = cc.run(args...) + if err != nil { + time.Sleep(10 * time.Minute) // XXX return "", fmt.Errorf("create failed: %v", err) } + // Strip the trailing newline from crictl output. return strings.TrimSpace(podID), nil } @@ -179,7 +230,7 @@ func (cc *Crictl) Import(image string) error { // be pushing a lot of bytes in order to import the image. The connect // timeout stays the same and is inherited from the Crictl instance. cmd := testutil.Command(cc.logger, - resolvePath("ctr"), + ResolvePath("ctr"), fmt.Sprintf("--connect-timeout=%s", 30*time.Second), fmt.Sprintf("--address=%s", cc.endpoint), "-n", "k8s.io", "images", "import", "-") @@ -260,7 +311,7 @@ func (cc *Crictl) StopContainer(contID string) error { // StartPodAndContainer starts a sandbox and container in that sandbox. It // returns the pod ID and container ID. -func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, string, error) { +func (cc *Crictl) StartPodAndContainer(runtime, image, sbSpec, contSpec string) (string, string, error) { if err := cc.Import(image); err != nil { return "", "", err } @@ -277,7 +328,7 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string, } cc.cleanup = append(cc.cleanup, cleanup) - podID, err := cc.RunPod(sbSpecFile) + podID, err := cc.RunPod(runtime, sbSpecFile) if err != nil { return "", "", err } @@ -307,7 +358,7 @@ func (cc *Crictl) StopPodAndContainer(podID, contID string) error { // run runs crictl with the given args. func (cc *Crictl) run(args ...string) (string, error) { defaultArgs := []string{ - resolvePath("crictl"), + ResolvePath("crictl"), "--image-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), "--runtime-endpoint", fmt.Sprintf("unix://%s", cc.endpoint), } diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD index 7c8758e35..a5e84658a 100644 --- a/pkg/test/dockerutil/BUILD +++ b/pkg/test/dockerutil/BUILD @@ -1,14 +1,42 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "dockerutil", testonly = 1, - srcs = ["dockerutil.go"], + srcs = [ + "container.go", + "dockerutil.go", + "exec.go", + "network.go", + "profile.go", + ], visibility = ["//:sandbox"], deps = [ "//pkg/test/testutil", - "@com_github_kr_pty//:go_default_library", + "@com_github_docker_docker//api/types:go_default_library", + "@com_github_docker_docker//api/types/container:go_default_library", + "@com_github_docker_docker//api/types/mount:go_default_library", + "@com_github_docker_docker//api/types/network:go_default_library", + "@com_github_docker_docker//client:go_default_library", + "@com_github_docker_docker//pkg/stdcopy:go_default_library", + "@com_github_docker_go_connections//nat:go_default_library", + ], +) + +go_test( + name = "profile_test", + size = "large", + srcs = [ + "profile_test.go", + ], + library = ":dockerutil", + tags = [ + # Requires docker and runsc to be configured before test runs. + # Also requires the test to be run as root. + "manual", + "local", ], + visibility = ["//:sandbox"], ) diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md new file mode 100644 index 000000000..870292096 --- /dev/null +++ b/pkg/test/dockerutil/README.md @@ -0,0 +1,86 @@ +# dockerutil + +This package is for creating and controlling docker containers for testing +runsc, gVisor's docker/kubernetes binary. A simple test may look like: + +``` + func TestSuperCool(t *testing.T) { + ctx := context.Background() + c := dockerutil.MakeContainer(ctx, t) + got, err := c.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine" + }, "echo", "super cool") + if err != nil { + t.Fatalf("err was not nil: %v", err) + } + want := "super cool" + if !strings.Contains(got, want){ + t.Fatalf("want: %s, got: %s", want, got) + } + } +``` + +For further examples, see many of our end to end tests elsewhere in the repo, +such as those in //test/e2e or benchmarks at //test/benchmarks. + +dockerutil uses the "official" docker golang api, which is +[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil +is a thin wrapper around this API, allowing desired new use cases to be easily +implemented. + +## Profiling + +dockerutil is capable of generating profiles. Currently, the only option is to +use pprof profiles generated by `runsc debug`. The profiler will generate Block, +CPU, Heap, Goroutine, and Mutex profiles. To generate profiles: + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or + `--vfs2`. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles run: + +``` +make sudo TARGETS=//path/to:target \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` + +Container name in most tests and benchmarks in gVisor is usually the test name +and some random characters like so: +`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2` + +Profiling requires root as runsc debug inspects running containers in /var/run +among other things. + +### Writing for Profiling + +The below shows an example of using profiles with dockerutil. + +``` +func TestSuperCool(t *testing.T){ + ctx := context.Background() + // profiled and using runtime from dockerutil.runtime flag + profiled := MakeContainer() + + // not profiled and using runtime runc + native := MakeNativeContainer() + + err := profiled.Spawn(ctx, RunOpts{ + Image: "some/image", + }, "sleep", "100000") + // profiling has begun here + ... + expensive setup that I don't want to profile. + ... + profiled.RestartProfiles() + // profiled activity +} +``` + +In the above example, `profiled` would be profiled and `native` would not. The +call to `RestartProfiles()` restarts the clock on profiling. This is useful if +the main activity being tested is done with `docker exec` or `container.Spawn()` +followed by one or more `container.Exec()` calls. diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go new file mode 100644 index 000000000..5a2157951 --- /dev/null +++ b/pkg/test/dockerutil/container.go @@ -0,0 +1,558 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net" + "os" + "path" + "regexp" + "strconv" + "strings" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/docker/go-connections/nat" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Container represents a Docker Container allowing +// user to configure and control as one would with the 'docker' +// client. Container is backed by the offical golang docker API. +// See: https://pkg.go.dev/github.com/docker/docker. +type Container struct { + Name string + runtime string + + logger testutil.Logger + client *client.Client + id string + mounts []mount.Mount + links []string + copyErr error + cleanups []func() + + // Profiles are profiles added to this container. They contain methods + // that are run after Creation, Start, and Cleanup of this Container, along + // a handle to restart the profile. Generally, tests/benchmarks using + // profiles need to run as root. + profiles []Profile + + // Stores streams attached to the container. Used by WaitForOutputSubmatch. + streams types.HijackedResponse + + // stores previously read data from the attached streams. + streamBuf bytes.Buffer +} + +// RunOpts are options for running a container. +type RunOpts struct { + // Image is the image relative to images/. This will be mangled + // appropriately, to ensure that only first-party images are used. + Image string + + // Memory is the memory limit in bytes. + Memory int + + // Cpus in which to allow execution. ("0", "1", "0-2"). + CpusetCpus string + + // Ports are the ports to be allocated. + Ports []int + + // WorkDir sets the working directory. + WorkDir string + + // ReadOnly sets the read-only flag. + ReadOnly bool + + // Env are additional environment variables. + Env []string + + // User is the user to use. + User string + + // Privileged enables privileged mode. + Privileged bool + + // CapAdd are the extra set of capabilities to add. + CapAdd []string + + // CapDrop are the extra set of capabilities to drop. + CapDrop []string + + // Mounts is the list of directories/files to be mounted inside the container. + Mounts []mount.Mount + + // Links is the list of containers to be connected to the container. + Links []string +} + +// MakeContainer sets up the struct for a Docker container. +// +// Names of containers will be unique. +// Containers will check flags for profiling requests. +func MakeContainer(ctx context.Context, logger testutil.Logger) *Container { + c := MakeNativeContainer(ctx, logger) + c.runtime = *runtime + if p := MakePprofFromFlags(c); p != nil { + c.AddProfile(p) + } + return c +} + +// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native +// containers aren't profiled. +func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container { + // Slashes are not allowed in container names. + name := testutil.RandomID(logger.Name()) + name = strings.ReplaceAll(name, "/", "-") + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return nil + } + client.NegotiateAPIVersion(ctx) + return &Container{ + logger: logger, + Name: name, + runtime: "", + client: client, + } +} + +// AddProfile adds a profile to this container. +func (c *Container) AddProfile(p Profile) { + c.profiles = append(c.profiles, p) +} + +// RestartProfiles calls Restart on all profiles for this container. +func (c *Container) RestartProfiles() error { + for _, profile := range c.profiles { + if err := profile.Restart(c); err != nil { + return err + } + } + return nil +} + +// Spawn is analogous to 'docker run -d'. +func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { + return err + } + return c.Start(ctx) +} + +// SpawnProcess is analogous to 'docker run -it'. It returns a process +// which represents the root process. +func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) (Process, error) { + config, hostconf, netconf := c.ConfigsFrom(r, args...) + config.Tty = true + config.OpenStdin = true + + if err := c.CreateFrom(ctx, config, hostconf, netconf); err != nil { + return Process{}, err + } + + if err := c.Start(ctx); err != nil { + return Process{}, err + } + + return Process{container: c, conn: c.streams}, nil +} + +// Run is analogous to 'docker run'. +func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) { + if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil { + return "", err + } + + if err := c.Start(ctx); err != nil { + return "", err + } + + if err := c.Wait(ctx); err != nil { + return "", err + } + + return c.Logs(ctx) +} + +// ConfigsFrom returns container configs from RunOpts and args. The caller should call 'CreateFrom' +// and Start. +func (c *Container) ConfigsFrom(r RunOpts, args ...string) (*container.Config, *container.HostConfig, *network.NetworkingConfig) { + return c.config(r, args), c.hostConfig(r), &network.NetworkingConfig{} +} + +// MakeLink formats a link to add to a RunOpts. +func (c *Container) MakeLink(target string) string { + return fmt.Sprintf("%s:%s", c.Name, target) +} + +// CreateFrom creates a container from the given configs. +func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + return c.create(ctx, conf, hostconf, netconf) +} + +// Create is analogous to 'docker create'. +func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error { + return c.create(ctx, c.config(r, args), c.hostConfig(r), nil) +} + +func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error { + cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name) + if err != nil { + return err + } + c.id = cont.ID + for _, profile := range c.profiles { + if err := profile.OnCreate(c); err != nil { + return fmt.Errorf("OnCreate method failed with: %v", err) + } + } + return nil +} + +func (c *Container) config(r RunOpts, args []string) *container.Config { + ports := nat.PortSet{} + for _, p := range r.Ports { + port := nat.Port(fmt.Sprintf("%d", p)) + ports[port] = struct{}{} + } + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + + return &container.Config{ + Image: testutil.ImageByName(r.Image), + Cmd: args, + ExposedPorts: ports, + Env: env, + WorkingDir: r.WorkDir, + User: r.User, + } +} + +func (c *Container) hostConfig(r RunOpts) *container.HostConfig { + c.mounts = append(c.mounts, r.Mounts...) + + return &container.HostConfig{ + Runtime: c.runtime, + Mounts: c.mounts, + PublishAllPorts: true, + Links: r.Links, + CapAdd: r.CapAdd, + CapDrop: r.CapDrop, + Privileged: r.Privileged, + ReadonlyRootfs: r.ReadOnly, + Resources: container.Resources{ + Memory: int64(r.Memory), // In bytes. + CpusetCpus: r.CpusetCpus, + }, + } +} + +// Start is analogous to 'docker start'. +func (c *Container) Start(ctx context.Context) error { + + // Open a connection to the container for parsing logs and for TTY. + streams, err := c.client.ContainerAttach(ctx, c.id, + types.ContainerAttachOptions{ + Stream: true, + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + return fmt.Errorf("failed to connect to container: %v", err) + } + + c.streams = streams + c.cleanups = append(c.cleanups, func() { + c.streams.Close() + }) + if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { + return fmt.Errorf("ContainerStart failed: %v", err) + } + for _, profile := range c.profiles { + if err := profile.OnStart(c); err != nil { + return fmt.Errorf("OnStart method failed: %v", err) + } + } + return nil +} + +// Stop is analogous to 'docker stop'. +func (c *Container) Stop(ctx context.Context) error { + return c.client.ContainerStop(ctx, c.id, nil) +} + +// Pause is analogous to'docker pause'. +func (c *Container) Pause(ctx context.Context) error { + return c.client.ContainerPause(ctx, c.id) +} + +// Unpause is analogous to 'docker unpause'. +func (c *Container) Unpause(ctx context.Context) error { + return c.client.ContainerUnpause(ctx, c.id) +} + +// Checkpoint is analogous to 'docker checkpoint'. +func (c *Container) Checkpoint(ctx context.Context, name string) error { + return c.client.CheckpointCreate(ctx, c.Name, types.CheckpointCreateOptions{CheckpointID: name, Exit: true}) +} + +// Restore is analogous to 'docker start --checkname [name]'. +func (c *Container) Restore(ctx context.Context, name string) error { + return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{CheckpointID: name}) +} + +// Logs is analogous 'docker logs'. +func (c *Container) Logs(ctx context.Context) (string, error) { + var out bytes.Buffer + err := c.logs(ctx, &out, &out) + return out.String(), err +} + +func (c *Container) logs(ctx context.Context, stdout, stderr *bytes.Buffer) error { + opts := types.ContainerLogsOptions{ShowStdout: true, ShowStderr: true} + writer, err := c.client.ContainerLogs(ctx, c.id, opts) + if err != nil { + return err + } + defer writer.Close() + _, err = stdcopy.StdCopy(stdout, stderr, writer) + + return err +} + +// ID returns the container id. +func (c *Container) ID() string { + return c.id +} + +// SandboxPid returns the container's pid. +func (c *Container) SandboxPid(ctx context.Context) (int, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, err + } + return resp.ContainerJSONBase.State.Pid, nil +} + +// FindIP returns the IP address of the container. +func (c *Container) FindIP(ctx context.Context, ipv6 bool) (net.IP, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return nil, err + } + + var ip net.IP + if ipv6 { + ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.GlobalIPv6Address) + } else { + ip = net.ParseIP(resp.NetworkSettings.DefaultNetworkSettings.IPAddress) + } + if ip == nil { + return net.IP{}, fmt.Errorf("invalid IP: %q", ip) + } + return ip, nil +} + +// FindPort returns the host port that is mapped to 'sandboxPort'. +func (c *Container) FindPort(ctx context.Context, sandboxPort int) (int, error) { + desc, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return -1, fmt.Errorf("error retrieving port: %v", err) + } + + format := fmt.Sprintf("%d/tcp", sandboxPort) + ports, ok := desc.NetworkSettings.Ports[nat.Port(format)] + if !ok { + return -1, fmt.Errorf("error retrieving port: %v", err) + + } + + port, err := strconv.Atoi(ports[0].HostPort) + if err != nil { + return -1, fmt.Errorf("error parsing port %q: %v", port, err) + } + return port, nil +} + +// CopyFiles copies in and mounts the given files. They are always ReadOnly. +func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { + dir, err := ioutil.TempDir("", c.Name) + if err != nil { + c.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) + return + } + c.cleanups = append(c.cleanups, func() { os.RemoveAll(dir) }) + if err := os.Chmod(dir, 0755); err != nil { + c.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) + return + } + for _, name := range sources { + src, err := testutil.FindFile(name) + if err != nil { + c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) + return + } + dst := path.Join(dir, path.Base(name)) + if err := testutil.Copy(src, dst); err != nil { + c.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) + return + } + c.logger.Logf("copy: %s -> %s", src, dst) + } + opts.Mounts = append(opts.Mounts, mount.Mount{ + Type: mount.TypeBind, + Source: dir, + Target: target, + ReadOnly: false, + }) +} + +// Status inspects the container returns its status. +func (c *Container) Status(ctx context.Context) (types.ContainerState, error) { + resp, err := c.client.ContainerInspect(ctx, c.id) + if err != nil { + return types.ContainerState{}, err + } + return *resp.State, err +} + +// Wait waits for the container to exit. +func (c *Container) Wait(ctx context.Context) error { + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case err := <-errChan: + return err + case <-statusChan: + return nil + } +} + +// WaitTimeout waits for the container to exit with a timeout. +func (c *Container) WaitTimeout(ctx context.Context, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + statusChan, errChan := c.client.ContainerWait(ctx, c.id, container.WaitConditionNotRunning) + select { + case <-ctx.Done(): + if ctx.Err() == context.DeadlineExceeded { + return fmt.Errorf("container %s timed out after %v seconds", c.Name, timeout.Seconds()) + } + return nil + case err := <-errChan: + return err + case <-statusChan: + return nil + } +} + +// WaitForOutput searches container logs for pattern and returns or timesout. +func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout time.Duration) (string, error) { + matches, err := c.WaitForOutputSubmatch(ctx, pattern, timeout) + if err != nil { + return "", err + } + if len(matches) == 0 { + return "", fmt.Errorf("didn't find pattern %s logs", pattern) + } + return matches[0], nil +} + +// WaitForOutputSubmatch searches container logs for the given +// pattern or times out. It returns any regexp submatches as well. +func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) { + re := regexp.MustCompile(pattern) + if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { + return matches, nil + } + + for exp := time.Now().Add(timeout); time.Now().Before(exp); { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) + _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader) + + if err != nil { + // check that it wasn't a timeout + if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() { + return nil, err + } + } + + if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { + return matches, nil + } + } + + return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String()) +} + +// Kill kills the container. +func (c *Container) Kill(ctx context.Context) error { + return c.client.ContainerKill(ctx, c.id, "") +} + +// Remove is analogous to 'docker rm'. +func (c *Container) Remove(ctx context.Context) error { + // Remove the image. + remove := types.ContainerRemoveOptions{ + RemoveVolumes: c.mounts != nil, + RemoveLinks: c.links != nil, + Force: true, + } + return c.client.ContainerRemove(ctx, c.Name, remove) +} + +// CleanUp kills and deletes the container (best effort). +func (c *Container) CleanUp(ctx context.Context) { + // Execute profile cleanups before the container goes down. + for _, profile := range c.profiles { + profile.OnCleanUp(c) + } + // Forget profiles. + c.profiles = nil + // Kill the container. + if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { + // Just log; can't do anything here. + c.logger.Logf("error killing container %q: %v", c.Name, err) + } + // Remove the image. + if err := c.Remove(ctx); err != nil { + c.logger.Logf("error removing container %q: %v", c.Name, err) + } + // Forget all mounts. + c.mounts = nil + // Execute all cleanups. + for _, c := range c.cleanups { + c() + } + c.cleanups = nil +} diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 819dd0a59..5a9dd8bd8 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -22,17 +22,11 @@ import ( "io" "io/ioutil" "log" - "net" - "os" "os/exec" - "path" "regexp" "strconv" - "strings" - "syscall" "time" - "github.com/kr/pty" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -49,6 +43,26 @@ var ( // config is the default Docker daemon configuration path. config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths") + + // The following flags are for the "pprof" profiler tool. + + // pprofBaseDir allows the user to change the directory to which profiles are + // written. By default, profiles will appear under: + // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof. + pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)") + + // duration is the max duration `runsc debug` will run and capture profiles. + // If the container's clean up method is called prior to duration, the + // profiling process will be killed. + duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds") + + // The below flags enable each type of profile. Multiple profiles can be + // enabled for each run. + pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") + pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") + pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug") + pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") + pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") ) // EnsureSupportedDockerVersion checks if correct docker is installed. @@ -127,595 +141,7 @@ func Save(logger testutil.Logger, image string, w io.Writer) error { return cmd.Run() } -// MountMode describes if the mount should be ro or rw. -type MountMode int - -const ( - // ReadOnly is what the name says. - ReadOnly MountMode = iota - // ReadWrite is what the name says. - ReadWrite -) - -// String returns the mount mode argument for this MountMode. -func (m MountMode) String() string { - switch m { - case ReadOnly: - return "ro" - case ReadWrite: - return "rw" - } - panic(fmt.Sprintf("invalid mode: %d", m)) -} - -// DockerNetwork contains the name of a docker network. -type DockerNetwork struct { - logger testutil.Logger - Name string - Subnet *net.IPNet - containers []*Docker -} - -// NewDockerNetwork sets up the struct for a Docker network. Names of networks -// will be unique. -func NewDockerNetwork(logger testutil.Logger) *DockerNetwork { - return &DockerNetwork{ - logger: logger, - Name: testutil.RandomID(logger.Name()), - } -} - -// Create calls 'docker network create'. -func (n *DockerNetwork) Create(args ...string) error { - a := []string{"docker", "network", "create"} - if n.Subnet != nil { - a = append(a, fmt.Sprintf("--subnet=%s", n.Subnet)) - } - a = append(a, args...) - a = append(a, n.Name) - return testutil.Command(n.logger, a...).Run() -} - -// Connect calls 'docker network connect' with the arguments provided. -func (n *DockerNetwork) Connect(container *Docker, args ...string) error { - a := []string{"docker", "network", "connect"} - a = append(a, args...) - a = append(a, n.Name, container.Name) - if err := testutil.Command(n.logger, a...).Run(); err != nil { - return err - } - n.containers = append(n.containers, container) - return nil -} - -// Cleanup cleans up the docker network and all the containers attached to it. -func (n *DockerNetwork) Cleanup() error { - for _, c := range n.containers { - // Don't propagate the error, it might be that the container - // was already cleaned up. - if err := c.Kill(); err != nil { - n.logger.Logf("unable to kill container during cleanup: %s", err) - } - } - - if err := testutil.Command(n.logger, "docker", "network", "rm", n.Name).Run(); err != nil { - return err - } - return nil -} - -// Docker contains the name and the runtime of a docker container. -type Docker struct { - logger testutil.Logger - Runtime string - Name string - copyErr error - cleanups []func() -} - -// MakeDocker sets up the struct for a Docker container. -// -// Names of containers will be unique. -func MakeDocker(logger testutil.Logger) *Docker { - // Slashes are not allowed in container names. - name := testutil.RandomID(logger.Name()) - name = strings.ReplaceAll(name, "/", "-") - - return &Docker{ - logger: logger, - Name: name, - Runtime: *runtime, - } -} - -// CopyFiles copies in and mounts the given files. They are always ReadOnly. -func (d *Docker) CopyFiles(opts *RunOpts, targetDir string, sources ...string) { - dir, err := ioutil.TempDir("", d.Name) - if err != nil { - d.copyErr = fmt.Errorf("ioutil.TempDir failed: %v", err) - return - } - d.cleanups = append(d.cleanups, func() { os.RemoveAll(dir) }) - if err := os.Chmod(dir, 0755); err != nil { - d.copyErr = fmt.Errorf("os.Chmod(%q, 0755) failed: %v", dir, err) - return - } - for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - d.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return - } - dst := path.Join(dir, path.Base(name)) - if err := testutil.Copy(src, dst); err != nil { - d.copyErr = fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err) - return - } - d.logger.Logf("copy: %s -> %s", src, dst) - } - opts.Mounts = append(opts.Mounts, Mount{ - Source: dir, - Target: targetDir, - Mode: ReadOnly, - }) -} - -// Mount describes a mount point inside the container. -type Mount struct { - // Source is the path outside the container. - Source string - - // Target is the path inside the container. - Target string - - // Mode tells whether the mount inside the container should be readonly. - Mode MountMode -} - -// Link informs dockers that a given container needs to be made accessible from -// the container being configured. -type Link struct { - // Source is the container to connect to. - Source *Docker - - // Target is the alias for the container. - Target string -} - -// RunOpts are options for running a container. -type RunOpts struct { - // Image is the image relative to images/. This will be mangled - // appropriately, to ensure that only first-party images are used. - Image string - - // Memory is the memory limit in kB. - Memory int - - // Ports are the ports to be allocated. - Ports []int - - // WorkDir sets the working directory. - WorkDir string - - // ReadOnly sets the read-only flag. - ReadOnly bool - - // Env are additional environment variables. - Env []string - - // User is the user to use. - User string - - // Privileged enables privileged mode. - Privileged bool - - // CapAdd are the extra set of capabilities to add. - CapAdd []string - - // CapDrop are the extra set of capabilities to drop. - CapDrop []string - - // Pty indicates that a pty will be allocated. If this is non-nil, then - // this will run after start-up with the *exec.Command and Pty file - // passed in to the function. - Pty func(*exec.Cmd, *os.File) - - // Foreground indicates that the container should be run in the - // foreground. If this is true, then the output will be available as a - // return value from the Run function. - Foreground bool - - // Mounts is the list of directories/files to be mounted inside the container. - Mounts []Mount - - // Links is the list of containers to be connected to the container. - Links []Link - - // Extra are extra arguments that may be passed. - Extra []string -} - -// args returns common arguments. -// -// Note that this does not define the complete behavior. -func (d *Docker) argsFor(r *RunOpts, command string, p []string) (rv []string) { - isExec := command == "exec" - isRun := command == "run" - - if isRun || isExec { - rv = append(rv, "-i") - } - if r.Pty != nil { - rv = append(rv, "-t") - } - if r.User != "" { - rv = append(rv, fmt.Sprintf("--user=%s", r.User)) - } - if r.Privileged { - rv = append(rv, "--privileged") - } - for _, c := range r.CapAdd { - rv = append(rv, fmt.Sprintf("--cap-add=%s", c)) - } - for _, c := range r.CapDrop { - rv = append(rv, fmt.Sprintf("--cap-drop=%s", c)) - } - for _, e := range r.Env { - rv = append(rv, fmt.Sprintf("--env=%s", e)) - } - if r.WorkDir != "" { - rv = append(rv, fmt.Sprintf("--workdir=%s", r.WorkDir)) - } - if !isExec { - if r.Memory != 0 { - rv = append(rv, fmt.Sprintf("--memory=%dk", r.Memory)) - } - for _, p := range r.Ports { - rv = append(rv, fmt.Sprintf("--publish=%d", p)) - } - if r.ReadOnly { - rv = append(rv, fmt.Sprintf("--read-only")) - } - if len(p) > 0 { - rv = append(rv, "--entrypoint=") - } - } - - // Always attach the test environment & Extra. - rv = append(rv, fmt.Sprintf("--env=RUNSC_TEST_NAME=%s", d.Name)) - rv = append(rv, r.Extra...) - - // Attach necessary bits. - if isExec { - rv = append(rv, d.Name) - } else { - for _, m := range r.Mounts { - rv = append(rv, fmt.Sprintf("-v=%s:%s:%v", m.Source, m.Target, m.Mode)) - } - for _, l := range r.Links { - rv = append(rv, fmt.Sprintf("--link=%s:%s", l.Source.Name, l.Target)) - } - - if len(d.Runtime) > 0 { - rv = append(rv, fmt.Sprintf("--runtime=%s", d.Runtime)) - } - rv = append(rv, fmt.Sprintf("--name=%s", d.Name)) - rv = append(rv, testutil.ImageByName(r.Image)) - } - - // Attach other arguments. - rv = append(rv, p...) - return rv -} - -// run runs a complete command. -func (d *Docker) run(r RunOpts, command string, p ...string) (string, error) { - if d.copyErr != nil { - return "", d.copyErr - } - basicArgs := []string{"docker"} - if command == "spawn" { - command = "run" - basicArgs = append(basicArgs, command) - basicArgs = append(basicArgs, "-d") - } else { - basicArgs = append(basicArgs, command) - } - customArgs := d.argsFor(&r, command, p) - cmd := testutil.Command(d.logger, append(basicArgs, customArgs...)...) - if r.Pty != nil { - // If allocating a terminal, then we just ignore the output - // from the command. - ptmx, err := pty.Start(cmd.Cmd) - if err != nil { - return "", err - } - defer cmd.Wait() // Best effort. - r.Pty(cmd.Cmd, ptmx) - } else { - // Can't support PTY or streaming. - out, err := cmd.CombinedOutput() - return string(out), err - } - return "", nil -} - -// Create calls 'docker create' with the arguments provided. -func (d *Docker) Create(r RunOpts, args ...string) error { - out, err := d.run(r, "create", args...) - if strings.Contains(out, "Unable to find image") { - return fmt.Errorf("unable to find image, did you remember to `make load-%s`: %w", r.Image, err) - } - return err -} - -// Start calls 'docker start'. -func (d *Docker) Start() error { - return testutil.Command(d.logger, "docker", "start", d.Name).Run() -} - -// Stop calls 'docker stop'. -func (d *Docker) Stop() error { - return testutil.Command(d.logger, "docker", "stop", d.Name).Run() -} - -// Run calls 'docker run' with the arguments provided. -func (d *Docker) Run(r RunOpts, args ...string) (string, error) { - return d.run(r, "run", args...) -} - -// Spawn starts the container and detaches. -func (d *Docker) Spawn(r RunOpts, args ...string) error { - _, err := d.run(r, "spawn", args...) - return err -} - -// Logs calls 'docker logs'. -func (d *Docker) Logs() (string, error) { - // Don't capture the output; since it will swamp the logs. - out, err := exec.Command("docker", "logs", d.Name).CombinedOutput() - return string(out), err -} - -// Exec calls 'docker exec' with the arguments provided. -func (d *Docker) Exec(r RunOpts, args ...string) (string, error) { - return d.run(r, "exec", args...) -} - -// Pause calls 'docker pause'. -func (d *Docker) Pause() error { - return testutil.Command(d.logger, "docker", "pause", d.Name).Run() -} - -// Unpause calls 'docker pause'. -func (d *Docker) Unpause() error { - return testutil.Command(d.logger, "docker", "unpause", d.Name).Run() -} - -// Checkpoint calls 'docker checkpoint'. -func (d *Docker) Checkpoint(name string) error { - return testutil.Command(d.logger, "docker", "checkpoint", "create", d.Name, name).Run() -} - -// Restore calls 'docker start --checkname [name]'. -func (d *Docker) Restore(name string) error { - return testutil.Command(d.logger, "docker", "start", fmt.Sprintf("--checkpoint=%s", name), d.Name).Run() -} - -// Kill calls 'docker kill'. -func (d *Docker) Kill() error { - // Skip logging this command, it will likely be an error. - out, err := exec.Command("docker", "kill", d.Name).CombinedOutput() - if err != nil && !strings.Contains(string(out), "is not running") { - return err - } - return nil -} - -// Remove calls 'docker rm'. -func (d *Docker) Remove() error { - return testutil.Command(d.logger, "docker", "rm", d.Name).Run() -} - -// CleanUp kills and deletes the container (best effort). -func (d *Docker) CleanUp() { - // Kill the container. - if err := d.Kill(); err != nil { - // Just log; can't do anything here. - d.logger.Logf("error killing container %q: %v", d.Name, err) - } - // Remove the image. - if err := d.Remove(); err != nil { - d.logger.Logf("error removing container %q: %v", d.Name, err) - } - // Execute all cleanups. - for _, c := range d.cleanups { - c() - } - d.cleanups = nil -} - -// FindPort returns the host port that is mapped to 'sandboxPort'. This calls -// docker to allocate a free port in the host and prevent conflicts. -func (d *Docker) FindPort(sandboxPort int) (int, error) { - format := fmt.Sprintf(`{{ (index (index .NetworkSettings.Ports "%d/tcp") 0).HostPort }}`, sandboxPort) - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving port: %v", err) - } - port, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing port %q: %v", out, err) - } - return port, nil -} - -// FindIP returns the IP address of the container. -func (d *Docker) FindIP() (net.IP, error) { - const format = `{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return net.IP{}, fmt.Errorf("error retrieving IP: %v", err) - } - ip := net.ParseIP(strings.TrimSpace(string(out))) - if ip == nil { - return net.IP{}, fmt.Errorf("invalid IP: %q", string(out)) - } - return ip, nil -} - -// A NetworkInterface is container's network interface information. -type NetworkInterface struct { - IPv4 net.IP - MAC net.HardwareAddr -} - -// ListNetworks returns the network interfaces of the container, keyed by -// Docker network name. -func (d *Docker) ListNetworks() (map[string]NetworkInterface, error) { - const format = `{{json .NetworkSettings.Networks}}` - out, err := testutil.Command(d.logger, "docker", "inspect", "-f", format, d.Name).CombinedOutput() - if err != nil { - return nil, fmt.Errorf("error network interfaces: %q: %w", string(out), err) - } - - networks := map[string]map[string]string{} - if err := json.Unmarshal(out, &networks); err != nil { - return nil, fmt.Errorf("error decoding network interfaces: %w", err) - } - - interfaces := map[string]NetworkInterface{} - for name, iface := range networks { - var netface NetworkInterface - - rawIP := strings.TrimSpace(iface["IPAddress"]) - if rawIP != "" { - ip := net.ParseIP(rawIP) - if ip == nil { - return nil, fmt.Errorf("invalid IP: %q", rawIP) - } - // Docker's IPAddress field is IPv4. The IPv6 address - // is stored in the GlobalIPv6Address field. - netface.IPv4 = ip - } - - rawMAC := strings.TrimSpace(iface["MacAddress"]) - if rawMAC != "" { - mac, err := net.ParseMAC(rawMAC) - if err != nil { - return nil, fmt.Errorf("invalid MAC: %q: %w", rawMAC, err) - } - netface.MAC = mac - } - - interfaces[name] = netface - } - - return interfaces, nil -} - -// SandboxPid returns the PID to the sandbox process. -func (d *Docker) SandboxPid() (int, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.State.Pid}}", d.Name).CombinedOutput() - if err != nil { - return -1, fmt.Errorf("error retrieving pid: %v", err) - } - pid, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - return -1, fmt.Errorf("error parsing pid %q: %v", out, err) - } - return pid, nil -} - -// ID returns the container ID. -func (d *Docker) ID() (string, error) { - out, err := testutil.Command(d.logger, "docker", "inspect", "-f={{.Id}}", d.Name).CombinedOutput() - if err != nil { - return "", fmt.Errorf("error retrieving ID: %v", err) - } - return strings.TrimSpace(string(out)), nil -} - -// Wait waits for container to exit, up to the given timeout. Returns error if -// wait fails or timeout is hit. Returns the application return code otherwise. -// Note that the application may have failed even if err == nil, always check -// the exit code. -func (d *Docker) Wait(timeout time.Duration) (syscall.WaitStatus, error) { - timeoutChan := time.After(timeout) - waitChan := make(chan (syscall.WaitStatus)) - errChan := make(chan (error)) - - go func() { - out, err := testutil.Command(d.logger, "docker", "wait", d.Name).CombinedOutput() - if err != nil { - errChan <- fmt.Errorf("error waiting for container %q: %v", d.Name, err) - } - exit, err := strconv.Atoi(strings.TrimSuffix(string(out), "\n")) - if err != nil { - errChan <- fmt.Errorf("error parsing exit code %q: %v", out, err) - } - waitChan <- syscall.WaitStatus(uint32(exit)) - }() - - select { - case ws := <-waitChan: - return ws, nil - case err := <-errChan: - return syscall.WaitStatus(1), err - case <-timeoutChan: - return syscall.WaitStatus(1), fmt.Errorf("timeout waiting for container %q", d.Name) - } -} - -// WaitForOutput calls 'docker logs' to retrieve containers output and searches -// for the given pattern. -func (d *Docker) WaitForOutput(pattern string, timeout time.Duration) (string, error) { - matches, err := d.WaitForOutputSubmatch(pattern, timeout) - if err != nil { - return "", err - } - if len(matches) == 0 { - return "", nil - } - return matches[0], nil -} - -// WaitForOutputSubmatch calls 'docker logs' to retrieve containers output and -// searches for the given pattern. It returns any regexp submatches as well. -func (d *Docker) WaitForOutputSubmatch(pattern string, timeout time.Duration) ([]string, error) { - re := regexp.MustCompile(pattern) - var ( - lastOut string - stopped bool - ) - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - out, err := d.Logs() - if err != nil { - return nil, err - } - if out != lastOut { - if lastOut == "" { - d.logger.Logf("output (start): %s", out) - } else if strings.HasPrefix(out, lastOut) { - d.logger.Logf("output (contn): %s", out[len(lastOut):]) - } else { - d.logger.Logf("output (trunc): %s", out) - } - lastOut = out // Save for future. - if matches := re.FindStringSubmatch(lastOut); matches != nil { - return matches, nil // Success! - } - } else if stopped { - // The sandbox stopped and we looked at the - // logs at least once since determining that. - return nil, fmt.Errorf("no longer running: %v", err) - } else if pid, err := d.SandboxPid(); pid == 0 || err != nil { - // The sandbox may have stopped, but it's - // possible that it has emitted the terminal - // line between the last call to Logs and here. - stopped = true - } - time.Sleep(100 * time.Millisecond) - } - return nil, fmt.Errorf("timeout waiting for output %q: %s", re.String(), lastOut) +// Runtime returns the value of the flag runtime. +func Runtime() string { + return *runtime } diff --git a/pkg/test/dockerutil/exec.go b/pkg/test/dockerutil/exec.go new file mode 100644 index 000000000..4c739c9e9 --- /dev/null +++ b/pkg/test/dockerutil/exec.go @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/pkg/stdcopy" +) + +// ExecOpts holds arguments for Exec calls. +type ExecOpts struct { + // Env are additional environment variables. + Env []string + + // Privileged enables privileged mode. + Privileged bool + + // User is the user to use. + User string + + // Enables Tty and stdin for the created process. + UseTTY bool + + // WorkDir is the working directory of the process. + WorkDir string +} + +// Exec creates a process inside the container. +func (c *Container) Exec(ctx context.Context, opts ExecOpts, args ...string) (string, error) { + p, err := c.doExec(ctx, opts, args) + if err != nil { + return "", err + } + + if exitStatus, err := p.WaitExitStatus(ctx); err != nil { + return "", err + } else if exitStatus != 0 { + out, _ := p.Logs() + return out, fmt.Errorf("process terminated with status: %d", exitStatus) + } + + return p.Logs() +} + +// ExecProcess creates a process inside the container and returns a process struct +// for the caller to use. +func (c *Container) ExecProcess(ctx context.Context, opts ExecOpts, args ...string) (Process, error) { + return c.doExec(ctx, opts, args) +} + +func (c *Container) doExec(ctx context.Context, r ExecOpts, args []string) (Process, error) { + config := c.execConfig(r, args) + resp, err := c.client.ContainerExecCreate(ctx, c.id, config) + if err != nil { + return Process{}, fmt.Errorf("exec create failed with err: %v", err) + } + + hijack, err := c.client.ContainerExecAttach(ctx, resp.ID, types.ExecStartCheck{}) + if err != nil { + return Process{}, fmt.Errorf("exec attach failed with err: %v", err) + } + + if err := c.client.ContainerExecStart(ctx, resp.ID, types.ExecStartCheck{}); err != nil { + hijack.Close() + return Process{}, fmt.Errorf("exec start failed with err: %v", err) + } + + return Process{ + container: c, + execid: resp.ID, + conn: hijack, + }, nil +} + +func (c *Container) execConfig(r ExecOpts, cmd []string) types.ExecConfig { + env := append(r.Env, fmt.Sprintf("RUNSC_TEST_NAME=%s", c.Name)) + return types.ExecConfig{ + AttachStdin: r.UseTTY, + AttachStderr: true, + AttachStdout: true, + Cmd: cmd, + Privileged: r.Privileged, + WorkingDir: r.WorkDir, + Env: env, + Tty: r.UseTTY, + User: r.User, + } + +} + +// Process represents a containerized process. +type Process struct { + container *Container + execid string + conn types.HijackedResponse +} + +// Write writes buf to the process's stdin. +func (p *Process) Write(timeout time.Duration, buf []byte) (int, error) { + p.conn.Conn.SetDeadline(time.Now().Add(timeout)) + return p.conn.Conn.Write(buf) +} + +// Read returns process's stdout and stderr. +func (p *Process) Read() (string, string, error) { + var stdout, stderr bytes.Buffer + if err := p.read(&stdout, &stderr); err != nil { + return "", "", err + } + return stdout.String(), stderr.String(), nil +} + +// Logs returns combined stdout/stderr from the process. +func (p *Process) Logs() (string, error) { + var out bytes.Buffer + if err := p.read(&out, &out); err != nil { + return "", err + } + return out.String(), nil +} + +func (p *Process) read(stdout, stderr *bytes.Buffer) error { + _, err := stdcopy.StdCopy(stdout, stderr, p.conn.Reader) + return err +} + +// ExitCode returns the process's exit code. +func (p *Process) ExitCode(ctx context.Context) (int, error) { + _, exitCode, err := p.runningExitCode(ctx) + return exitCode, err +} + +// IsRunning checks if the process is running. +func (p *Process) IsRunning(ctx context.Context) (bool, error) { + running, _, err := p.runningExitCode(ctx) + return running, err +} + +// WaitExitStatus until process completes and returns exit status. +func (p *Process) WaitExitStatus(ctx context.Context) (int, error) { + waitChan := make(chan (int)) + errChan := make(chan (error)) + + go func() { + for { + running, exitcode, err := p.runningExitCode(ctx) + if err != nil { + errChan <- fmt.Errorf("error waiting process %s: container %v", p.execid, p.container.Name) + } + if !running { + waitChan <- exitcode + } + time.Sleep(time.Millisecond * 500) + } + }() + + select { + case ws := <-waitChan: + return ws, nil + case err := <-errChan: + return -1, err + } +} + +// runningExitCode collects if the process is running and the exit code. +// The exit code is only valid if the process has exited. +func (p *Process) runningExitCode(ctx context.Context) (bool, int, error) { + // If execid is not empty, this is a execed process. + if p.execid != "" { + status, err := p.container.client.ContainerExecInspect(ctx, p.execid) + return status.Running, status.ExitCode, err + } + // else this is the root process. + status, err := p.container.Status(ctx) + return status.Running, status.ExitCode, err +} diff --git a/pkg/test/dockerutil/network.go b/pkg/test/dockerutil/network.go new file mode 100644 index 000000000..047091e75 --- /dev/null +++ b/pkg/test/dockerutil/network.go @@ -0,0 +1,113 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "net" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/client" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Network is a docker network. +type Network struct { + client *client.Client + id string + logger testutil.Logger + Name string + containers []*Container + Subnet *net.IPNet +} + +// NewNetwork sets up the struct for a Docker network. Names of networks +// will be unique. +func NewNetwork(ctx context.Context, logger testutil.Logger) *Network { + client, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + logger.Logf("create client failed with: %v", err) + return nil + } + client.NegotiateAPIVersion(ctx) + + return &Network{ + logger: logger, + Name: testutil.RandomID(logger.Name()), + client: client, + } +} + +func (n *Network) networkCreate() types.NetworkCreate { + + var subnet string + if n.Subnet != nil { + subnet = n.Subnet.String() + } + + ipam := network.IPAM{ + Config: []network.IPAMConfig{{ + Subnet: subnet, + }}, + } + + return types.NetworkCreate{ + CheckDuplicate: true, + IPAM: &ipam, + } +} + +// Create is analogous to 'docker network create'. +func (n *Network) Create(ctx context.Context) error { + + opts := n.networkCreate() + resp, err := n.client.NetworkCreate(ctx, n.Name, opts) + if err != nil { + return err + } + n.id = resp.ID + return nil +} + +// Connect is analogous to 'docker network connect' with the arguments provided. +func (n *Network) Connect(ctx context.Context, container *Container, ipv4, ipv6 string) error { + settings := network.EndpointSettings{ + IPAMConfig: &network.EndpointIPAMConfig{ + IPv4Address: ipv4, + IPv6Address: ipv6, + }, + } + err := n.client.NetworkConnect(ctx, n.id, container.id, &settings) + if err == nil { + n.containers = append(n.containers, container) + } + return err +} + +// Inspect returns this network's info. +func (n *Network) Inspect(ctx context.Context) (types.NetworkResource, error) { + return n.client.NetworkInspect(ctx, n.id, types.NetworkInspectOptions{Verbose: true}) +} + +// Cleanup cleans up the docker network and all the containers attached to it. +func (n *Network) Cleanup(ctx context.Context) error { + for _, c := range n.containers { + c.CleanUp(ctx) + } + n.containers = nil + + return n.client.NetworkRemove(ctx, n.id) +} diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go new file mode 100644 index 000000000..1fab33083 --- /dev/null +++ b/pkg/test/dockerutil/profile.go @@ -0,0 +1,152 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "time" +) + +// Profile represents profile-like operations on a container, +// such as running perf or pprof. It is meant to be added to containers +// such that the container type calls the Profile during its lifecycle. +type Profile interface { + // OnCreate is called just after the container is created when the container + // has a valid ID (e.g. c.ID()). + OnCreate(c *Container) error + + // OnStart is called just after the container is started when the container + // has a valid Pid (e.g. c.SandboxPid()). + OnStart(c *Container) error + + // Restart restarts the Profile on request. + Restart(c *Container) error + + // OnCleanUp is called during the container's cleanup method. + // Cleanups should just log errors if they have them. + OnCleanUp(c *Container) error +} + +// Pprof is for running profiles with 'runsc debug'. Pprof workloads +// should be run as root and ONLY against runsc sandboxes. The runtime +// should have --profile set as an option in /etc/docker/daemon.json in +// order for profiling to work with Pprof. +type Pprof struct { + BasePath string // path to put profiles + BlockProfile bool + CPUProfile bool + GoRoutineProfile bool + HeapProfile bool + MutexProfile bool + Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. + shouldRun bool + cmd *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser +} + +// MakePprofFromFlags makes a Pprof profile from flags. +func MakePprofFromFlags(c *Container) *Pprof { + if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) { + return nil + } + return &Pprof{ + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + BlockProfile: *pprofBlock, + CPUProfile: *pprofCPU, + GoRoutineProfile: *pprofGo, + HeapProfile: *pprofHeap, + MutexProfile: *pprofMutex, + Duration: *duration, + } +} + +// OnCreate implements Profile.OnCreate. +func (p *Pprof) OnCreate(c *Container) error { + return os.MkdirAll(p.BasePath, 0755) +} + +// OnStart implements Profile.OnStart. +func (p *Pprof) OnStart(c *Container) error { + path, err := RuntimePath() + if err != nil { + return fmt.Errorf("failed to get runtime path: %v", err) + } + + // The root directory of this container's runtime. + root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime) + // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`. + args := []string{root, "debug"} + args = append(args, p.makeProfileArgs(c)...) + args = append(args, c.ID()) + + // Best effort wait until container is running. + for now := time.Now(); time.Since(now) < 5*time.Second; { + if status, err := c.Status(context.Background()); err != nil { + return fmt.Errorf("failed to get status with: %v", err) + + } else if status.Running { + break + } + time.Sleep(500 * time.Millisecond) + } + p.cmd = exec.Command(path, args...) + if err := p.cmd.Start(); err != nil { + return fmt.Errorf("process failed: %v", err) + } + return nil +} + +// Restart implements Profile.Restart. +func (p *Pprof) Restart(c *Container) error { + p.OnCleanUp(c) + return p.OnStart(c) +} + +// OnCleanUp implements Profile.OnCleanup +func (p *Pprof) OnCleanUp(c *Container) error { + defer func() { p.cmd = nil }() + if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() { + return p.cmd.Process.Kill() + } + return nil +} + +// makeProfileArgs turns Pprof fields into runsc debug flags. +func (p *Pprof) makeProfileArgs(c *Container) []string { + var ret []string + if p.BlockProfile { + ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof"))) + } + if p.CPUProfile { + ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) + } + if p.GoRoutineProfile { + ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof"))) + } + if p.HeapProfile { + ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) + } + if p.MutexProfile { + ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof"))) + } + ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration)) + return ret +} diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go new file mode 100644 index 000000000..b7b4d7618 --- /dev/null +++ b/pkg/test/dockerutil/profile_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dockerutil + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +type testCase struct { + name string + pprof Pprof + expectedFiles []string +} + +func TestPprof(t *testing.T) { + // Basepath and expected file names for each type of profile. + basePath := "/tmp/test/profile" + block := "block.pprof" + cpu := "cpu.pprof" + goprofle := "go.pprof" + heap := "heap.pprof" + mutex := "mutex.pprof" + + testCases := []testCase{ + { + name: "Cpu", + pprof: Pprof{ + BasePath: basePath, + CPUProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{cpu}, + }, + { + name: "All", + pprof: Pprof{ + BasePath: basePath, + BlockProfile: true, + CPUProfile: true, + GoRoutineProfile: true, + HeapProfile: true, + MutexProfile: true, + Duration: 2 * time.Second, + }, + expectedFiles: []string{block, cpu, goprofle, heap, mutex}, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + c := MakeContainer(ctx, t) + // Set basepath to include the container name so there are no conflicts. + tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name) + c.AddProfile(&tc.pprof) + + func() { + defer c.CleanUp(ctx) + // Start a container. + if err := c.Spawn(ctx, RunOpts{ + Image: "basic/alpine", + }, "sleep", "1000"); err != nil { + t.Fatalf("run failed with: %v", err) + } + + if status, err := c.Status(context.Background()); !status.Running { + t.Fatalf("container is not yet running: %+v err: %v", status, err) + } + + // End early if the expected files exist and have data. + for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) { + if err := checkFiles(tc); err == nil { + break + } + } + }() + + // Check all expected files exist and have data. + if err := checkFiles(tc); err != nil { + t.Fatalf(err.Error()) + } + }) + } +} + +func checkFiles(tc testCase) error { + for _, file := range tc.expectedFiles { + stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file)) + if err != nil { + return fmt.Errorf("stat failed with: %v", err) + } else if stat.Size() < 1 { + return fmt.Errorf("file not written to: %+v", stat) + } + } + return nil +} + +func TestMain(m *testing.M) { + EnsureSupportedDockerVersion() + os.Exit(m.Run()) +} diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD index 03b1b4677..2d8f56bc0 100644 --- a/pkg/test/testutil/BUILD +++ b/pkg/test/testutil/BUILD @@ -15,6 +15,6 @@ go_library( "//runsc/boot", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index f21d6769a..64c292698 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -482,6 +482,21 @@ func IsStatic(filename string) (bool, error) { return true, nil } +// TouchShardStatusFile indicates to Bazel that the test runner supports +// sharding by creating or updating the last modified date of the file +// specified by TEST_SHARD_STATUS_FILE. +// +// See https://docs.bazel.build/versions/master/test-encyclopedia.html#role-of-the-test-runner. +func TouchShardStatusFile() error { + if statusFile := os.Getenv("TEST_SHARD_STATUS_FILE"); statusFile != "" { + cmd := exec.Command("touch", statusFile) + if b, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("touch %q failed:\n output: %s\n error: %s", statusFile, string(b), err.Error()) + } + } + return nil +} + // TestIndicesForShard returns indices for this test shard based on the // TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. // diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 707eb085b..67a950444 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -128,13 +128,6 @@ type EntryCallback interface { // // +stateify savable type Entry struct { - // Context stores any state the waiter may wish to store in the entry - // itself, which may be used at wake up time. - // - // Note that use of this field is optional and state may alternatively be - // stored in the callback itself. - Context interface{} - Callback EntryCallback // The following fields are protected by the queue lock. @@ -142,13 +135,14 @@ type Entry struct { waiterEntry } -type channelCallback struct{} +type channelCallback struct { + ch chan struct{} +} // Callback implements EntryCallback.Callback. -func (*channelCallback) Callback(e *Entry) { - ch := e.Context.(chan struct{}) +func (c *channelCallback) Callback(*Entry) { select { - case ch <- struct{}{}: + case c.ch <- struct{}{}: default: } } @@ -164,7 +158,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { c = make(chan struct{}, 1) } - return Entry{Context: c, Callback: &channelCallback{}}, c + return Entry{Callback: &channelCallback{ch: c}}, c } // Queue represents the wait queue where waiters can be added and diff --git a/runsc/BUILD b/runsc/BUILD index 757f6d44c..96f697a5f 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -62,18 +62,22 @@ go_binary( ) pkg_tar( - name = "runsc-bin", - srcs = [":runsc"], + name = "debian-bin", + srcs = [ + ":runsc", + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", + ], mode = "0755", package_dir = "/usr/bin", - strip_prefix = "/runsc/linux_amd64_pure_stripped", ) pkg_tar( name = "debian-data", extension = "tar.gz", deps = [ - ":runsc-bin", + ":debian-bin", + "//shim:config", ], ) diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 2b1e6b13e..9f52438c2 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -40,6 +40,8 @@ go_library( "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/control", "//pkg/sentry/devices/memdev", + "//pkg/sentry/devices/ttydev", + "//pkg/sentry/devices/tundev", "//pkg/sentry/fdimport", "//pkg/sentry/fs", "//pkg/sentry/fs/dev", @@ -53,6 +55,7 @@ go_library( "//pkg/sentry/fs/user", "//pkg/sentry/fsimpl/devpts", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/fuse", "//pkg/sentry/fsimpl/gofer", "//pkg/sentry/fsimpl/host", "//pkg/sentry/fsimpl/overlay", @@ -87,6 +90,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/packetsocket", "//pkg/tcpip/link/qdisc/fifo", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", @@ -103,7 +107,7 @@ go_library( "//runsc/boot/pprof", "//runsc/specutils", "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -128,7 +132,7 @@ go_test( "//pkg/sync", "//pkg/unet", "//runsc/fsgofer", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/boot/config.go b/runsc/boot/config.go index bb01b8fb5..80da8b3e6 100644 --- a/runsc/boot/config.go +++ b/runsc/boot/config.go @@ -274,6 +274,9 @@ type Config struct { // Enables VFS2 (not plumbled through yet). VFS2 bool + + // Enables FUSE usage (not plumbled through yet). + FUSE bool } // ToFlags returns a slice of flags that correspond to the given Config. @@ -325,5 +328,9 @@ func (c *Config) ToFlags() []string { f = append(f, "--vfs2=true") } + if c.FUSE { + f = append(f, "--fuse=true") + } + return f } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 8125d5061..3e5e4c22f 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -155,7 +155,7 @@ func newController(fd int, l *Loader) (*controller, error) { srv.Register(&debug{}) srv.Register(&control.Logging{}) - if l.conf.ProfileEnable { + if l.root.conf.ProfileEnable { srv.Register(&control.Profile{ Kernel: l.k, }) @@ -333,7 +333,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Pause the kernel while we build a new one. cm.l.k.Pause() - p, err := createPlatform(cm.l.conf, deviceFile) + p, err := createPlatform(cm.l.root.conf, deviceFile) if err != nil { return fmt.Errorf("creating platform: %v", err) } @@ -349,8 +349,8 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { cm.l.k = k // Set up the restore environment. - mntr := newContainerMounter(cm.l.spec, cm.l.goferFDs, cm.l.k, cm.l.mountHints) - renv, err := mntr.createRestoreEnvironment(cm.l.conf) + mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints) + renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) if err != nil { return fmt.Errorf("creating RestoreEnvironment: %v", err) } @@ -368,7 +368,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { return fmt.Errorf("file cannot be empty") } - if cm.l.conf.ProfileEnable { + if cm.l.root.conf.ProfileEnable { // pprof.Initialize opens /proc/self/maps, so has to be called before // installing seccomp filters. pprof.Initialize() @@ -387,13 +387,13 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Since we have a new kernel we also must make a new watchdog. dogOpts := watchdog.DefaultOpts - dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction + dogOpts.TaskTimeoutAction = cm.l.root.conf.WatchdogAction dog := watchdog.New(k, dogOpts) // Change the loader fields to reflect the changes made when restoring. cm.l.k = k cm.l.watchdog = dog - cm.l.rootProcArgs = kernel.CreateProcessArgs{} + cm.l.root.procArgs = kernel.CreateProcessArgs{} cm.l.restore = true // Reinitialize the sandbox ID and processes map. Note that it doesn't diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 60e33425f..149eb0b1b 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -310,19 +310,12 @@ var allowedSyscalls = seccomp.SyscallRules{ }, }, syscall.SYS_WRITE: {}, - // The only user in rawfile.NonBlockingWrite3 always passes iovcnt with - // values 2 or 3. Three iovec-s are passed, when the PACKET_VNET_HDR - // option is enabled for a packet socket. + // For rawfile.NonBlockingWriteIovec. syscall.SYS_WRITEV: []seccomp.Rule{ { seccomp.AllowAny{}, seccomp.AllowAny{}, - seccomp.AllowValue(2), - }, - { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(3), + seccomp.GreaterThan(0), }, }, } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index e83584b82..59639ba19 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -29,6 +29,7 @@ import ( _ "gvisor.dev/gvisor/pkg/sentry/fs/sys" _ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs" _ "gvisor.dev/gvisor/pkg/sentry/fs/tty" + "gvisor.dev/gvisor/pkg/sentry/vfs" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" @@ -390,6 +391,10 @@ type mountHint struct { // root is the inode where the volume is mounted. For mounts with 'pod' share // the volume is mounted once and then bind mounted inside the containers. root *fs.Inode + + // vfsMount is the master mount for the volume. For mounts with 'pod' share + // the master volume is bind mounted inside the containers. + vfsMount *vfs.Mount } func (m *mountHint) setField(key, val string) error { @@ -571,9 +576,9 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin // processHints processes annotations that container hints about how volumes // should be mounted (e.g. a volume shared between containers). It must be // called for the root container only. -func (c *containerMounter) processHints(conf *Config) error { +func (c *containerMounter) processHints(conf *Config, creds *auth.Credentials) error { if conf.VFS2 { - return nil + return c.processHintsVFS2(conf, creds) } ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 081db39c1..9cd9c5909 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -77,29 +77,34 @@ import ( _ "gvisor.dev/gvisor/pkg/sentry/socket/unix" ) -// Loader keeps state needed to start the kernel and run the container.. -type Loader struct { - // k is the kernel. - k *kernel.Kernel - - // ctrl is the control server. - ctrl *controller - +type containerInfo struct { conf *Config - // console is set to true if terminal is enabled. - console bool + // spec is the base configuration for the root container. + spec *specs.Spec - watchdog *watchdog.Watchdog + // procArgs refers to the container's init task. + procArgs kernel.CreateProcessArgs // stdioFDs contains stdin, stdout, and stderr. stdioFDs []int // goferFDs are the FDs that attach the sandbox to the gofers. goferFDs []int +} - // spec is the base configuration for the root container. - spec *specs.Spec +// Loader keeps state needed to start the kernel and run the container.. +type Loader struct { + // k is the kernel. + k *kernel.Kernel + + // ctrl is the control server. + ctrl *controller + + // root contains information about the root container in the sandbox. + root containerInfo + + watchdog *watchdog.Watchdog // stopSignalForwarding disables forwarding of signals to the sandboxed // container. It should be called when a sandbox is destroyed. @@ -108,9 +113,6 @@ type Loader struct { // restore is set to true if we are restoring a container. restore bool - // rootProcArgs refers to the root sandbox init task. - rootProcArgs kernel.CreateProcessArgs - // sandboxID is the ID for the whole sandbox. sandboxID string @@ -175,8 +177,6 @@ type Args struct { // StdioFDs is the stdio for the application. The Loader takes ownership of // these FDs and may close them at any time. StdioFDs []int - // Console is set to true if using TTY. - Console bool // NumCPU is the number of CPUs to create inside the sandbox. NumCPU int // TotalMem is the initial amount of total memory to report back to the @@ -205,6 +205,10 @@ func New(args Args) (*Loader, error) { // Is this a VFSv2 kernel? if args.Conf.VFS2 { kernel.VFS2Enabled = true + if args.Conf.FUSE { + kernel.FUSEEnabled = true + } + vfs2.Override() } @@ -227,9 +231,7 @@ func New(args Args) (*Loader, error) { // Create VDSO. // // Pass k as the platform since it is savable, unlike the actual platform. - // - // FIXME(b/109889800): Use non-nil context. - vdso, err := loader.PrepareVDSO(nil, k) + vdso, err := loader.PrepareVDSO(k) if err != nil { return nil, fmt.Errorf("creating vdso: %v", err) } @@ -300,6 +302,12 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("initializing kernel: %v", err) } + if kernel.VFS2Enabled { + if err := registerFilesystems(k); err != nil { + return nil, fmt.Errorf("registering filesystems: %w", err) + } + } + if err := adjustDirentCache(k); err != nil { return nil, err } @@ -318,7 +326,7 @@ func New(args Args) (*Loader, error) { dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction dog := watchdog.New(k, dogOpts) - procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) + procArgs, err := createProcessArgs(args.ID, args.Spec, creds, k, k.RootPIDNamespace()) if err != nil { return nil, fmt.Errorf("creating init process for root container: %v", err) } @@ -366,17 +374,18 @@ func New(args Args) (*Loader, error) { eid := execID{cid: args.ID} l := &Loader{ - k: k, - conf: args.Conf, - console: args.Console, - watchdog: dog, - spec: args.Spec, - goferFDs: args.GoferFDs, - stdioFDs: stdioFDs, - rootProcArgs: procArgs, - sandboxID: args.ID, - processes: map[execID]*execProcess{eid: {}}, - mountHints: mountHints, + k: k, + watchdog: dog, + sandboxID: args.ID, + processes: map[execID]*execProcess{eid: {}}, + mountHints: mountHints, + root: containerInfo{ + conf: args.Conf, + stdioFDs: stdioFDs, + goferFDs: args.GoferFDs, + spec: args.Spec, + procArgs: procArgs, + }, } // We don't care about child signals; some platforms can generate a @@ -404,8 +413,8 @@ func New(args Args) (*Loader, error) { return l, nil } -// newProcess creates a process that can be run with kernel.CreateProcess. -func newProcess(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) { +// createProcessArgs creates args that can be used with kernel.CreateProcess. +func createProcessArgs(id string, spec *specs.Spec, creds *auth.Credentials, k *kernel.Kernel, pidns *kernel.PIDNamespace) (kernel.CreateProcessArgs, error) { // Create initial limits. ls, err := createLimitSet(spec) if err != nil { @@ -479,13 +488,13 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { } func (l *Loader) installSeccompFilters() error { - if l.conf.DisableSeccomp { + if l.root.conf.DisableSeccomp { filter.Report("syscall filter is DISABLED. Running in less secure mode.") } else { opts := filter.Options{ Platform: l.k.Platform, - HostNetwork: l.conf.Network == NetworkHost, - ProfileEnable: l.conf.ProfileEnable, + HostNetwork: l.root.conf.Network == NetworkHost, + ProfileEnable: l.root.conf.ProfileEnable, ControllerFD: l.ctrl.srv.FD(), } if err := filter.Install(opts); err != nil { @@ -511,7 +520,7 @@ func (l *Loader) Run() error { } func (l *Loader) run() error { - if l.conf.Network == NetworkHost { + if l.root.conf.Network == NetworkHost { // Delay host network configuration to this point because network namespace // is configured after the loader is created and before Run() is called. log.Debugf("Configuring host network") @@ -532,10 +541,8 @@ func (l *Loader) run() error { // If we are restoring, we do not want to create a process. // l.restore is set by the container manager when a restore call is made. - var ttyFile *host.TTYFileOperations - var ttyFileVFS2 *hostvfs2.TTYFileDescription if !l.restore { - if l.conf.ProfileEnable { + if l.root.conf.ProfileEnable { pprof.Initialize() } @@ -545,82 +552,29 @@ func (l *Loader) run() error { return err } - // Create the FD map, which will set stdin, stdout, and stderr. If console - // is true, then ioctl calls will be passed through to the host fd. - ctx := l.rootProcArgs.NewContext(l.k) - var err error - - // CreateProcess takes a reference on FDMap if successful. We won't need - // ours either way. - l.rootProcArgs.FDTable, ttyFile, ttyFileVFS2, err = createFDTable(ctx, l.console, l.stdioFDs) - if err != nil { - return fmt.Errorf("importing fds: %v", err) - } - - // Setup the root container file system. - l.startGoferMonitor(l.sandboxID, l.goferFDs) - - mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints) - if err := mntr.processHints(l.conf); err != nil { - return err - } - if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil { - return err - } - - // Add the HOME enviroment variable if it is not already set. - var envv []string - if kernel.VFS2Enabled { - envv, err = user.MaybeAddExecUserHomeVFS2(ctx, l.rootProcArgs.MountNamespaceVFS2, - l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) - - } else { - envv, err = user.MaybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, - l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv) - } - if err != nil { - return err - } - l.rootProcArgs.Envv = envv - // Create the root container init task. It will begin running // when the kernel is started. - if _, _, err := l.k.CreateProcess(l.rootProcArgs); err != nil { - return fmt.Errorf("creating init process: %v", err) + if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil { + return err } - - // CreateProcess takes a reference on FDTable if successful. - l.rootProcArgs.FDTable.DecRef() } ep.tg = l.k.GlobalInit() - if ns, ok := specutils.GetNS(specs.PIDNamespace, l.spec); ok { + if ns, ok := specutils.GetNS(specs.PIDNamespace, l.root.spec); ok { ep.pidnsPath = ns.Path } - if l.console { - // Set the foreground process group on the TTY to the global init process - // group, since that is what we are about to start running. - switch { - case ttyFileVFS2 != nil: - ep.ttyVFS2 = ttyFileVFS2 - ttyFileVFS2.InitForegroundProcessGroup(ep.tg.ProcessGroup()) - case ttyFile != nil: - ep.tty = ttyFile - ttyFile.InitForegroundProcessGroup(ep.tg.ProcessGroup()) - } - } // Handle signals by forwarding them to the root container process // (except for panic signal, which should cause a panic). l.stopSignalForwarding = sighandling.StartSignalForwarding(func(sig linux.Signal) { // Panic signal should cause a panic. - if l.conf.PanicSignal != -1 && sig == linux.Signal(l.conf.PanicSignal) { + if l.root.conf.PanicSignal != -1 && sig == linux.Signal(l.root.conf.PanicSignal) { panic("Signal-induced panic") } // Otherwise forward to root container. deliveryMode := DeliverToProcess - if l.console { + if l.root.spec.Process.Terminal { // Since we are running with a console, we should forward the signal to // the foreground process group so that job control signals like ^C can // be handled properly. @@ -637,7 +591,7 @@ func (l *Loader) run() error { // during restore, we can release l.stdioFDs now. VFS2 takes ownership of the // passed FDs, so only close for VFS1. if !kernel.VFS2Enabled { - for _, fd := range l.stdioFDs { + for _, fd := range l.root.stdioFDs { err := syscall.Close(fd) if err != nil { return fmt.Errorf("close dup()ed stdioFDs: %v", err) @@ -676,8 +630,8 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file l.mu.Lock() defer l.mu.Unlock() - eid := execID{cid: cid} - if _, ok := l.processes[eid]; !ok { + ep := l.processes[execID{cid: cid}] + if ep == nil { return fmt.Errorf("trying to start a deleted container %q", cid) } @@ -711,76 +665,112 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file if pidns == nil { pidns = l.k.RootPIDNamespace().NewChild(l.k.RootUserNamespace()) } - l.processes[eid].pidnsPath = ns.Path + ep.pidnsPath = ns.Path } else { pidns = l.k.RootPIDNamespace() } - procArgs, err := newProcess(cid, spec, creds, l.k, pidns) + + info := &containerInfo{ + conf: conf, + spec: spec, + } + info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns) if err != nil { return fmt.Errorf("creating new process: %v", err) } // setupContainerFS() dups stdioFDs, so we don't need to dup them here. - var stdioFDs []int for _, f := range files[:3] { - stdioFDs = append(stdioFDs, int(f.Fd())) + info.stdioFDs = append(info.stdioFDs, int(f.Fd())) } - // Create the FD map, which will set stdin, stdout, and stderr. - ctx := procArgs.NewContext(l.k) - fdTable, _, _, err := createFDTable(ctx, false, stdioFDs) - if err != nil { - return fmt.Errorf("importing fds: %v", err) - } - // CreateProcess takes a reference on fdTable if successful. We won't - // need ours either way. - procArgs.FDTable = fdTable - // Can't take ownership away from os.File. dup them to get a new FDs. - var goferFDs []int for _, f := range files[3:] { fd, err := syscall.Dup(int(f.Fd())) if err != nil { return fmt.Errorf("failed to dup file: %v", err) } - goferFDs = append(goferFDs, fd) + info.goferFDs = append(info.goferFDs, fd) + } + + tg, err := l.createContainerProcess(false, cid, info, ep) + if err != nil { + return err + } + + // Success! + l.k.StartProcess(tg) + ep.tg = tg + return nil +} + +func (l *Loader) createContainerProcess(root bool, cid string, info *containerInfo, ep *execProcess) (*kernel.ThreadGroup, error) { + console := false + if root { + // Only root container supports terminal for now. + console = info.spec.Process.Terminal } + // Create the FD map, which will set stdin, stdout, and stderr. + ctx := info.procArgs.NewContext(l.k) + fdTable, ttyFile, ttyFileVFS2, err := createFDTable(ctx, console, info.stdioFDs) + if err != nil { + return nil, fmt.Errorf("importing fds: %v", err) + } + // CreateProcess takes a reference on fdTable if successful. We won't need + // ours either way. + info.procArgs.FDTable = fdTable + // Setup the child container file system. - l.startGoferMonitor(cid, goferFDs) + l.startGoferMonitor(cid, info.goferFDs) - mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints) - if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil { - return err + mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints) + if root { + if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil { + return nil, err + } + } + if err := setupContainerFS(ctx, info.conf, mntr, &info.procArgs); err != nil { + return nil, err } // Add the HOME enviroment variable if it is not already set. var envv []string if kernel.VFS2Enabled { - envv, err = user.MaybeAddExecUserHomeVFS2(ctx, procArgs.MountNamespaceVFS2, - procArgs.Credentials.RealKUID, procArgs.Envv) + envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2, + info.procArgs.Credentials.RealKUID, info.procArgs.Envv) } else { - envv, err = user.MaybeAddExecUserHome(ctx, procArgs.MountNamespace, - procArgs.Credentials.RealKUID, procArgs.Envv) + envv, err = user.MaybeAddExecUserHome(ctx, info.procArgs.MountNamespace, + info.procArgs.Credentials.RealKUID, info.procArgs.Envv) } if err != nil { - return err + return nil, err } - procArgs.Envv = envv + info.procArgs.Envv = envv // Create and start the new process. - tg, _, err := l.k.CreateProcess(procArgs) + tg, _, err := l.k.CreateProcess(info.procArgs) if err != nil { - return fmt.Errorf("creating process: %v", err) + return nil, fmt.Errorf("creating process: %v", err) } - l.k.StartProcess(tg) - // CreateProcess takes a reference on FDTable if successful. - procArgs.FDTable.DecRef() + info.procArgs.FDTable.DecRef() - l.processes[eid].tg = tg - return nil + // Set the foreground process group on the TTY to the global init process + // group, since that is what we are about to start running. + if root { + switch { + case ttyFileVFS2 != nil: + ep.ttyVFS2 = ttyFileVFS2 + ttyFileVFS2.InitForegroundProcessGroup(tg.ProcessGroup()) + case ttyFile != nil: + ep.tty = ttyFile + ttyFile.InitForegroundProcessGroup(tg.ProcessGroup()) + } + } + + return tg, nil } // startGoferMonitor runs a goroutine to monitor gofer's health. It polls on @@ -1058,7 +1048,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(true)); err != nil { + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { return nil, fmt.Errorf("failed to enable SACK: %s", err) } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index e448fd773..8e6fe57e1 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -479,13 +479,13 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { defer l.Destroy() defer loaderCleanup() - mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints) - if err := mntr.processHints(l.conf); err != nil { + mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints) + if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil { t.Fatalf("failed process hints: %v", err) } ctx := l.k.SupervisorContext() - mns, err := mntr.setupVFS2(ctx, l.conf, &l.rootProcArgs) + mns, err := mntr.setupVFS2(ctx, l.root.conf, &l.root.procArgs) if err != nil { t.Fatalf("failed to setupVFS2: %v", err) } @@ -499,7 +499,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { Path: fspath.Parse(p), } - if d, err := l.k.VFS().GetDentryAt(ctx, l.rootProcArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil { + if d, err := l.k.VFS().GetDentryAt(ctx, l.root.procArgs.Credentials, target, &vfs.GetDentryOptions{}); err != nil { t.Errorf("expected path %v to exist with spec %v, but got error %v", p, tc.spec, err) } else { d.DecRef() diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 14d2f56a5..4e1fa7665 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/packetsocket" "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -252,6 +253,9 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) } + // Enable support for AF_PACKET sockets to receive outgoing packets. + linkEP = packetsocket.New(linkEP) + log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels) if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil { return err diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index d1653b279..cfe2d36aa 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -26,9 +26,12 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/devices/memdev" + "gvisor.dev/gvisor/pkg/sentry/devices/ttydev" + "gvisor.dev/gvisor/pkg/sentry/devices/tundev" "gvisor.dev/gvisor/pkg/sentry/fs/user" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse" "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer" "gvisor.dev/gvisor/pkg/sentry/fsimpl/overlay" "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc" @@ -40,7 +43,11 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) error { +func registerFilesystems(k *kernel.Kernel) error { + ctx := k.SupervisorContext() + creds := auth.NewRootCredentials(k.RootUserNamespace()) + vfsObj := k.VFS() + vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, // TODO(b/29356795): Users may mount this once the terminals are in a @@ -70,11 +77,28 @@ func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, cre AllowUserMount: true, AllowUserList: true, }) + vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) // Setup files in devtmpfs. if err := memdev.Register(vfsObj); err != nil { return fmt.Errorf("registering memdev: %w", err) } + if err := ttydev.Register(vfsObj); err != nil { + return fmt.Errorf("registering ttydev: %w", err) + } + + if kernel.FUSEEnabled { + if err := fuse.Register(vfsObj); err != nil { + return fmt.Errorf("registering fusedev: %w", err) + } + } + + if err := tundev.Register(vfsObj); err != nil { + return fmt.Errorf("registering tundev: %v", err) + } a, err := devtmpfs.NewAccessor(ctx, vfsObj, creds, devtmpfs.Name) if err != nil { return fmt.Errorf("creating devtmpfs accessor: %w", err) @@ -85,15 +109,25 @@ func registerFilesystems(ctx context.Context, vfsObj *vfs.VirtualFilesystem, cre return fmt.Errorf("initializing userspace: %w", err) } if err := memdev.CreateDevtmpfsFiles(ctx, a); err != nil { - return fmt.Errorf("creating devtmpfs files: %w", err) + return fmt.Errorf("creating memdev devtmpfs files: %w", err) + } + if err := ttydev.CreateDevtmpfsFiles(ctx, a); err != nil { + return fmt.Errorf("creating ttydev devtmpfs files: %w", err) + } + if err := tundev.CreateDevtmpfsFiles(ctx, a); err != nil { + return fmt.Errorf("creating tundev devtmpfs files: %v", err) + } + + if kernel.FUSEEnabled { + if err := fuse.CreateDevtmpfsFile(ctx, a); err != nil { + return fmt.Errorf("creating fusedev devtmpfs files: %w", err) + } } + return nil } func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { - if err := mntr.k.VFS().Init(); err != nil { - return fmt.Errorf("failed to initialize VFS: %w", err) - } mns, err := mntr.setupVFS2(ctx, conf, procArgs) if err != nil { return fmt.Errorf("failed to setupFS: %w", err) @@ -122,10 +156,6 @@ func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals rootCtx := procArgs.NewContext(c.k) - if err := registerFilesystems(rootCtx, c.k.VFS(), rootCreds); err != nil { - return nil, fmt.Errorf("register filesystems: %w", err) - } - mns, err := c.createMountNamespaceVFS2(rootCtx, conf, rootCreds) if err != nil { return nil, fmt.Errorf("creating mount namespace: %w", err) @@ -141,10 +171,19 @@ func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *Config, creds *auth.Credentials) (*vfs.MountNamespace, error) { fd := c.fds.remove() - opts := strings.Join(p9MountData(fd, conf.FileAccess, true /* vfs2 */), ",") + opts := p9MountData(fd, conf.FileAccess, true /* vfs2 */) + + if conf.OverlayfsStaleRead { + // We can't check for overlayfs here because sandbox is chroot'ed and gofer + // can only send mount options for specs.Mounts (specs.Root is missing + // Options field). So assume root is always on top of overlayfs. + opts = append(opts, "overlayfs_stale_read") + } log.Infof("Mounting root over 9P, ioFD: %d", fd) - mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{Data: opts}) + mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{ + Data: strings.Join(opts, ","), + }) if err != nil { return nil, fmt.Errorf("setting up mount namespace: %w", err) } @@ -160,8 +199,14 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, for i := range mounts { submount := &mounts[i] log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) - if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil { - return err + if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() { + if err := c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint); err != nil { + return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err) + } + } else { + if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil { + return fmt.Errorf("mount submount %q: %w", submount.Destination, err) + } } } @@ -235,20 +280,18 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, // getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values // used for mounts. func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndFD) (string, *vfs.MountOptions, error) { - var ( - fsName string - data []string - ) + fsName := m.Type + var data []string // Find filesystem name and FS specific data field. switch m.Type { case devpts.Name, devtmpfs.Name, proc.Name, sys.Name: - fsName = m.Type + // Nothing to do. + case nonefs: fsName = sys.Name - case tmpfs.Name: - fsName = m.Type + case tmpfs.Name: var err error data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...) if err != nil { @@ -257,10 +300,16 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF case bind: fsName = gofer.Name + if m.fd == 0 { + // Check that an FD was provided to fails fast. Technically FD=0 is valid, + // but unlikely to be correct in this context. + return "", nil, fmt.Errorf("9P mount requires a connection FD") + } data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) default: log.Warningf("ignoring unknown filesystem type %q", m.Type) + return "", nil, nil } opts := &vfs.MountOptions{ @@ -300,7 +349,7 @@ func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath s } _, err := c.k.VFS().StatAt(ctx, creds, target, &vfs.StatOptions{}) if err == nil { - // Mount point exists, nothing else to do. + log.Debugf("Mount point %q already exists", currentPath) return nil } if err != syserror.ENOENT { @@ -378,3 +427,76 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds return fmt.Errorf(`stating "/tmp" inside container: %w`, err) } } + +// processHintsVFS2 processes annotations that container hints about how volumes +// should be mounted (e.g. a volume shared between containers). It must be +// called for the root container only. +func (c *containerMounter) processHintsVFS2(conf *Config, creds *auth.Credentials) error { + ctx := c.k.SupervisorContext() + for _, hint := range c.hints.mounts { + // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a + // common gofer to mount all shared volumes. + if hint.mount.Type != tmpfs.Name { + continue + } + + log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type) + mnt, err := c.mountSharedMasterVFS2(ctx, conf, hint, creds) + if err != nil { + return fmt.Errorf("mounting shared master %q: %v", hint.name, err) + } + hint.vfsMount = mnt + } + return nil +} + +// mountSharedMasterVFS2 mounts the master of a volume that is shared among +// containers in a pod. +func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { + // Map mount type to filesystem name, and parse out the options that we are + // capable of dealing with. + mntFD := &mountAndFD{Mount: hint.mount} + fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, mntFD) + if err != nil { + return nil, err + } + if len(fsName) == 0 { + return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type) + } + return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts) +} + +// mountSharedSubmount binds mount to a previously mounted volume that is shared +// among containers in the same pod. +func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) error { + if err := source.checkCompatible(mount); err != nil { + return err + } + + _, opts, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) + if err != nil { + return err + } + newMnt, err := c.k.VFS().NewDisconnectedMount(source.vfsMount.Filesystem(), source.vfsMount.Root(), opts) + if err != nil { + return err + } + defer newMnt.DecRef() + + root := mns.Root() + defer root.DecRef() + if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil { + return err + } + + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(mount.Destination), + } + if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil { + return err + } + log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name) + return nil +} diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD index 7e34a284a..37f4253ba 100644 --- a/runsc/cgroup/BUILD +++ b/runsc/cgroup/BUILD @@ -10,7 +10,7 @@ go_library( "//pkg/cleanup", "//pkg/log", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) @@ -22,6 +22,6 @@ go_test( tags = ["local"], deps = [ "//pkg/test/testutil", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index e5cc9d622..8fbc3887a 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -92,7 +92,17 @@ func setOptionalValueUint16(path, name string, val *uint16) error { func setValue(path, name, data string) error { fullpath := filepath.Join(path, name) - return ioutil.WriteFile(fullpath, []byte(data), 0700) + + // Retry writes on EINTR; see: + // https://github.com/golang/go/issues/38033 + for { + err := ioutil.WriteFile(fullpath, []byte(data), 0700) + if err == nil { + return nil + } else if !errors.Is(err, syscall.EINTR) { + return err + } + } } func getValue(path, name string) (string, error) { @@ -132,8 +142,16 @@ func fillFromAncestor(path string) (string, error) { if err != nil { return "", err } - if err := ioutil.WriteFile(path, []byte(val), 0700); err != nil { - return "", err + + // Retry writes on EINTR; see: + // https://github.com/golang/go/issues/38033 + for { + err := ioutil.WriteFile(path, []byte(val), 0700) + if err == nil { + break + } else if !errors.Is(err, syscall.EINTR) { + return "", err + } } return val, nil } diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index af3538ef0..1b5178dd5 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -45,7 +45,7 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/platform", - "//pkg/state", + "//pkg/state/pretty", "//pkg/state/statefile", "//pkg/sync", "//pkg/unet", @@ -58,7 +58,7 @@ go_library( "//runsc/fsgofer/filter", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -87,9 +87,9 @@ go_test( "//runsc/boot", "//runsc/container", "//runsc/specutils", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", ], ) diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go index 01204ab4d..f4f247721 100644 --- a/runsc/cmd/boot.go +++ b/runsc/cmd/boot.go @@ -54,10 +54,6 @@ type Boot struct { // provided in that order. stdioFDs intFlags - // console is set to true if the sandbox should allow terminal ioctl(2) - // syscalls. - console bool - // applyCaps determines if capabilities defined in the spec should be applied // to the process. applyCaps bool @@ -115,7 +111,6 @@ func (b *Boot) SetFlags(f *flag.FlagSet) { f.IntVar(&b.deviceFD, "device-fd", -1, "FD for the platform device file") f.Var(&b.ioFDs, "io-fds", "list of FDs to connect 9P clients. They must follow this order: root first, then mounts as defined in the spec") f.Var(&b.stdioFDs, "stdio-fds", "list of FDs containing sandbox stdin, stdout, and stderr in that order") - f.BoolVar(&b.console, "console", false, "set to true if the sandbox should allow terminal ioctl(2) syscalls") f.BoolVar(&b.applyCaps, "apply-caps", false, "if true, apply capabilities defined in the spec to the process") f.BoolVar(&b.setUpRoot, "setup-root", false, "if true, set up an empty root for the process") f.BoolVar(&b.pidns, "pidns", false, "if true, the sandbox is in its own PID namespace") @@ -229,7 +224,6 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Device: os.NewFile(uintptr(b.deviceFD), "platform device"), GoferFDs: b.ioFDs.GetArray(), StdioFDs: b.stdioFDs.GetArray(), - Console: b.console, NumCPU: b.cpuNum, TotalMem: b.totalMem, UserLogFD: b.userLogFD, diff --git a/runsc/cmd/spec.go b/runsc/cmd/spec.go index a2b0a4b14..55194e641 100644 --- a/runsc/cmd/spec.go +++ b/runsc/cmd/spec.go @@ -16,124 +16,122 @@ package cmd import ( "context" - "fmt" - "io/ioutil" + "encoding/json" + "io" "os" "path/filepath" "github.com/google/subcommands" + specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/runsc/flag" ) -func genSpec(cwd string) []byte { - var template = fmt.Sprintf(`{ - "ociVersion": "1.0.0", - "process": { - "terminal": true, - "user": { - "uid": 0, - "gid": 0 - }, - "args": [ - "sh" - ], - "env": [ - "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", - "TERM=xterm" - ], - "cwd": "%s", - "capabilities": { - "bounding": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "effective": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "inheritable": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "permitted": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ], - "ambient": [ - "CAP_AUDIT_WRITE", - "CAP_KILL", - "CAP_NET_BIND_SERVICE" - ] - }, - "rlimits": [ - { - "type": "RLIMIT_NOFILE", - "hard": 1024, - "soft": 1024 - } - ] - }, - "root": { - "path": "rootfs", - "readonly": true - }, - "hostname": "runsc", - "mounts": [ - { - "destination": "/proc", - "type": "proc", - "source": "proc" +func writeSpec(w io.Writer, cwd string, netns string, args []string) error { + spec := &specs.Spec{ + Version: "1.0.0", + Process: &specs.Process{ + Terminal: true, + User: specs.User{ + UID: 0, + GID: 0, + }, + Args: args, + Env: []string{ + "PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin", + "TERM=xterm", + }, + Cwd: cwd, + Capabilities: &specs.LinuxCapabilities{ + Bounding: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Effective: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Inheritable: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + Permitted: []string{ + "CAP_AUDIT_WRITE", + "CAP_KILL", + "CAP_NET_BIND_SERVICE", + }, + // TODO(gvisor.dev/issue/3166): support ambient capabilities + }, + Rlimits: []specs.POSIXRlimit{ + { + Type: "RLIMIT_NOFILE", + Hard: 1024, + Soft: 1024, + }, + }, }, - { - "destination": "/dev", - "type": "tmpfs", - "source": "tmpfs", - "options": [] + Root: &specs.Root{ + Path: "rootfs", + Readonly: true, }, - { - "destination": "/sys", - "type": "sysfs", - "source": "sysfs", - "options": [ - "nosuid", - "noexec", - "nodev", - "ro" - ] - } - ], - "linux": { - "namespaces": [ + Hostname: "runsc", + Mounts: []specs.Mount{ { - "type": "pid" + Destination: "/proc", + Type: "proc", + Source: "proc", }, { - "type": "network" + Destination: "/dev", + Type: "tmpfs", + Source: "tmpfs", }, { - "type": "ipc" + Destination: "/sys", + Type: "sysfs", + Source: "sysfs", + Options: []string{ + "nosuid", + "noexec", + "nodev", + "ro", + }, }, - { - "type": "uts" + }, + Linux: &specs.Linux{ + Namespaces: []specs.LinuxNamespace{ + { + Type: "pid", + }, + { + Type: "network", + Path: netns, + }, + { + Type: "ipc", + }, + { + Type: "uts", + }, + { + Type: "mount", + }, }, - { - "type": "mount" - } - ] + }, } -}`, cwd) - return []byte(template) + e := json.NewEncoder(w) + e.SetIndent("", " ") + return e.Encode(spec) } // Spec implements subcommands.Command for the "spec" command. type Spec struct { bundle string cwd string + netns string } // Name implements subcommands.Command.Name. @@ -148,21 +146,26 @@ func (*Spec) Synopsis() string { // Usage implements subcommands.Command.Usage. func (*Spec) Usage() string { - return `spec [options] - create a new OCI bundle specification file. + return `spec [options] [-- args...] - create a new OCI bundle specification file. + +The spec command creates a new specification file (config.json) for a new OCI +bundle. -The spec command creates a new specification file (config.json) for a new OCI bundle. +The specification file is a starter file that runs the command specified by +'args' in the container. If 'args' is not specified the default is to run the +'sh' program. -The specification file is a starter file that runs the "sh" command in the container. You -should edit the file to suit your needs. You can find out more about the format of the -specification file by visiting the OCI runtime spec repository: +While a number of flags are provided to change values in the specification, you +can examine the file and edit it to suit your needs after this command runs. +You can find out more about the format of the specification file by visiting +the OCI runtime spec repository: https://github.com/opencontainers/runtime-spec/ EXAMPLE: $ mkdir -p bundle/rootfs $ cd bundle - $ runsc spec + $ runsc spec -- /hello $ docker export $(docker create hello-world) | tar -xf - -C rootfs - $ sed -i 's;"sh";"/hello";' config.json $ sudo runsc run hello ` @@ -173,18 +176,29 @@ func (s *Spec) SetFlags(f *flag.FlagSet) { f.StringVar(&s.bundle, "bundle", ".", "path to the root of the OCI bundle") f.StringVar(&s.cwd, "cwd", "/", "working directory that will be set for the executable, "+ "this value MUST be an absolute path") + f.StringVar(&s.netns, "netns", "", "network namespace path") } // Execute implements subcommands.Command.Execute. func (s *Spec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + // Grab the arguments. + containerArgs := f.Args() + if len(containerArgs) == 0 { + containerArgs = []string{"sh"} + } + confPath := filepath.Join(s.bundle, "config.json") if _, err := os.Stat(confPath); !os.IsNotExist(err) { Fatalf("file %q already exists", confPath) } - var spec = genSpec(s.cwd) + configFile, err := os.OpenFile(confPath, os.O_WRONLY|os.O_CREATE, 0664) + if err != nil { + Fatalf("opening file %q: %v", confPath, err) + } - if err := ioutil.WriteFile(confPath, spec, 0664); err != nil { + err = writeSpec(configFile, s.cwd, s.netns, containerArgs) + if err != nil { Fatalf("writing to %q: %v", confPath, err) } diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go index e6f1907da..daed9e728 100644 --- a/runsc/cmd/statefile.go +++ b/runsc/cmd/statefile.go @@ -20,7 +20,7 @@ import ( "os" "github.com/google/subcommands" - "gvisor.dev/gvisor/pkg/state" + "gvisor.dev/gvisor/pkg/state/pretty" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/runsc/flag" ) @@ -105,8 +105,14 @@ func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interfac if err != nil { Fatalf("error parsing statefile: %v", err) } - if err := state.PrettyPrint(output, rc, s.html); err != nil { - Fatalf("error printing state: %v", err) + if s.html { + if err := pretty.PrintHTML(output, rc); err != nil { + Fatalf("error printing state: %v", err) + } + } else { + if err := pretty.PrintText(output, rc); err != nil { + Fatalf("error printing state: %v", err) + } } return subcommands.ExitSuccess } diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 49cfb0837..9a9ee7e2a 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -27,7 +27,7 @@ go_library( "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_gofrs_flock//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], ) @@ -68,7 +68,7 @@ go_test( "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_kr_pty//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 3813c6b93..995d4e267 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -122,6 +122,7 @@ func TestConsoleSocket(t *testing.T) { for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { spec := testutil.NewSpecWithArgs("true") + spec.Process.Terminal = true _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) diff --git a/runsc/container/container.go b/runsc/container/container.go index 6d297d0df..7ad09bf23 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -324,7 +324,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { } } if err := runInCgroup(cg, func() error { - ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir) + ioFiles, specFile, err := c.createGoferProcess(args.Spec, conf, args.BundleDir, args.Attached) if err != nil { return err } @@ -427,7 +427,7 @@ func (c *Container) Start(conf *boot.Config) error { // the start (and all their children processes). if err := runInCgroup(c.Sandbox.Cgroup, func() error { // Create the gofer process. - ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir) + ioFiles, mountsFile, err := c.createGoferProcess(c.Spec, conf, c.BundleDir, false) if err != nil { return err } @@ -861,7 +861,7 @@ func (c *Container) waitForStopped() error { return backoff.Retry(op, b) } -func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string) ([]*os.File, *os.File, error) { +func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string, attached bool) ([]*os.File, *os.File, error) { // Start with the general config flags. args := conf.ToFlags() @@ -955,6 +955,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund cmd.ExtraFiles = goferEnds cmd.Args[0] = "runsc-gofer" + if attached { + // The gofer is attached to the lifetime of this process, so it + // should synchronously die when this process dies. + cmd.SysProcAttr = &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGKILL, + } + } + // Enter new namespaces to isolate from the rest of the system. Don't unshare // cgroup because gofer is added to a cgroup in the caller's namespace. nss := []specs.LinuxNamespace{ diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index cd76645bd..5e8247bc8 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -643,7 +643,9 @@ func TestExec(t *testing.T) { if err != nil { t.Fatalf("error creating temporary directory: %v", err) } - cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir) + // Note that some shells may exec the final command in a sequence as + // an optimization. We avoid this here by adding the exit 0. + cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100 && exit 0", dir) spec := testutil.NewSpecWithArgs("sh", "-c", cmd) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index a27a01942..e189648f4 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -100,19 +100,20 @@ type execDesc struct { c *Container cmd []string want int - desc string + name string } -func execMany(execs []execDesc) error { +func execMany(t *testing.T, execs []execDesc) { for _, exec := range execs { - args := &control.ExecArgs{Argv: exec.cmd} - if ws, err := exec.c.executeSync(args); err != nil { - return fmt.Errorf("error executing %+v: %v", args, err) - } else if ws.ExitStatus() != exec.want { - return fmt.Errorf("%q: exec %q got exit status: %d, want: %d", exec.desc, exec.cmd, ws.ExitStatus(), exec.want) - } + t.Run(exec.name, func(t *testing.T) { + args := &control.ExecArgs{Argv: exec.cmd} + if ws, err := exec.c.executeSync(args); err != nil { + t.Errorf("error executing %+v: %v", args, err) + } else if ws.ExitStatus() != exec.want { + t.Errorf("%q: exec %q got exit status: %d, want: %d", exec.name, exec.cmd, ws.ExitStatus(), exec.want) + } + }) } - return nil } func createSharedMount(mount specs.Mount, name string, pod ...*specs.Spec) { @@ -1072,7 +1073,7 @@ func TestMultiContainerContainerDestroyStress(t *testing.T) { // Test that pod shared mounts are properly mounted in 2 containers and that // changes from one container is reflected in the other. func TestMultiContainerSharedMount(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1110,84 +1111,82 @@ func TestMultiContainerSharedMount(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", + name: "directory is mounted in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", + name: "directory is mounted in container1", }, { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", + cmd: []string{"/bin/touch", file0}, + name: "create file in container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", + name: "file appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", + name: "file appears in container1", }, { c: containers[1], cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", + name: "remove file from container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", + name: "file removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, { c: containers[1], cmd: []string{"/bin/mkdir", file1}, - desc: "create directory in container1", + name: "create directory in container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-d", file0}, - desc: "dir appears in container0", + name: "dir appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", file1}, - desc: "dir appears in container1", + name: "dir appears in container1", }, { c: containers[0], cmd: []string{"/bin/rmdir", file0}, - desc: "create directory in container0", + name: "remove directory from container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-d", file0}, - desc: "dir removed from container0", + name: "dir removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-d", file1}, - desc: "dir removed from container1", + name: "dir removed from container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } // Test that pod mounts are mounted as readonly when requested. func TestMultiContainerSharedMountReadonly(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1225,35 +1224,34 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", + name: "directory is mounted in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", + name: "directory is mounted in container1", }, { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, + cmd: []string{"/bin/touch", file0}, want: 1, - desc: "fails to write to container0", + name: "fails to write to container0", }, { c: containers[1], - cmd: []string{"/usr/bin/touch", file1}, + cmd: []string{"/bin/touch", file1}, want: 1, - desc: "fails to write to container1", + name: "fails to write to container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } // Test that shared pod mounts continue to work after container is restarted. func TestMultiContainerSharedMountRestart(t *testing.T) { + //TODO(gvisor.dev/issue/1487): This is failing with VFS2. for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() @@ -1291,23 +1289,21 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { execs := []execDesc{ { c: containers[0], - cmd: []string{"/usr/bin/touch", file0}, - desc: "create file in container0", + cmd: []string{"/bin/touch", file0}, + name: "create file in container0", }, { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file appears in container0", + name: "file appears in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file appears in container1", + name: "file appears in container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) containers[1].Destroy() @@ -1334,32 +1330,30 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { { c: containers[0], cmd: []string{"/usr/bin/test", "-f", file0}, - desc: "file is still in container0", + name: "file is still in container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "-f", file1}, - desc: "file is still in container1", + name: "file is still in container1", }, { c: containers[1], cmd: []string{"/bin/rm", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, { c: containers[0], cmd: []string{"/usr/bin/test", "!", "-f", file0}, - desc: "file removed from container0", + name: "file removed from container0", }, { c: containers[1], cmd: []string{"/usr/bin/test", "!", "-f", file1}, - desc: "file removed from container1", + name: "file removed from container1", }, } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) - } + execMany(t, execs) }) } } @@ -1367,53 +1361,53 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { // Test that unsupported pod mounts options are ignored when matching master and // slave mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { - rootDir, cleanup, err := testutil.SetupRootDir() - if err != nil { - t.Fatalf("error creating root dir: %v", err) - } - defer cleanup() - - conf := testutil.TestConfig(t) - conf.RootDir = rootDir + for name, conf := range configsWithVFS2(t, all...) { + t.Run(name, func(t *testing.T) { + rootDir, cleanup, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer cleanup() + conf.RootDir = rootDir - // Setup the containers. - sleep := []string{"/bin/sleep", "100"} - podSpec, ids := createSpecs(sleep, sleep) - mnt0 := specs.Mount{ - Destination: "/mydir/test", - Source: "/some/dir", - Type: "tmpfs", - Options: []string{"rw", "rbind", "relatime"}, - } - podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) + // Setup the containers. + sleep := []string{"/bin/sleep", "100"} + podSpec, ids := createSpecs(sleep, sleep) + mnt0 := specs.Mount{ + Destination: "/mydir/test", + Source: "/some/dir", + Type: "tmpfs", + Options: []string{"rw", "rbind", "relatime"}, + } + podSpec[0].Mounts = append(podSpec[0].Mounts, mnt0) - mnt1 := mnt0 - mnt1.Destination = "/mydir2/test2" - mnt1.Options = []string{"rw", "nosuid"} - podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) + mnt1 := mnt0 + mnt1.Destination = "/mydir2/test2" + mnt1.Options = []string{"rw", "nosuid"} + podSpec[1].Mounts = append(podSpec[1].Mounts, mnt1) - createSharedMount(mnt0, "test-mount", podSpec...) + createSharedMount(mnt0, "test-mount", podSpec...) - containers, cleanup, err := startContainers(conf, podSpec, ids) - if err != nil { - t.Fatalf("error starting containers: %v", err) - } - defer cleanup() + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() - execs := []execDesc{ - { - c: containers[0], - cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, - desc: "directory is mounted in container0", - }, - { - c: containers[1], - cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, - desc: "directory is mounted in container1", - }, - } - if err := execMany(execs); err != nil { - t.Fatal(err.Error()) + execs := []execDesc{ + { + c: containers[0], + cmd: []string{"/usr/bin/test", "-d", mnt0.Destination}, + name: "directory is mounted in container0", + }, + { + c: containers[1], + cmd: []string{"/usr/bin/test", "-d", mnt1.Destination}, + name: "directory is mounted in container1", + }, + } + execMany(t, execs) + }) } } diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh index dc7aeee87..d1e28e17b 100755 --- a/runsc/debian/postinst.sh +++ b/runsc/debian/postinst.sh @@ -18,7 +18,14 @@ if [ "$1" != configure ]; then exit 0 fi +# Update docker configuration. if [ -f /etc/docker/daemon.json ]; then runsc install - systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 + if systemctl status docker 2>/dev/null; then + systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 + fi fi + +# For containerd-based installers, we don't automatically update the +# configuration. If it uses a v2 shim, then it will find the package binaries +# automatically when provided the appropriate annotation. diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 1036b0630..05e3637f7 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -31,5 +31,6 @@ go_test( deps = [ "//pkg/log", "//pkg/p9", + "//pkg/test/testutil", ], ) diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 1dce36965..88814b83c 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -128,6 +128,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_MADVISE: {}, unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init(). syscall.SYS_MKDIRAT: {}, + syscall.SYS_MKNODAT: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. // diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index 74977c313..c6694c278 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -48,36 +48,6 @@ const ( openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC ) -type fileType int - -const ( - regular fileType = iota - directory - symlink - socket - unknown -) - -// String implements fmt.Stringer. -func (f fileType) String() string { - switch f { - case regular: - return "regular" - case directory: - return "directory" - case symlink: - return "symlink" - case socket: - return "socket" - } - return "unknown" -} - -// ControlSocketAddr generates an abstract unix socket name for the given id. -func ControlSocketAddr(id string) string { - return fmt.Sprintf("\x00runsc-gofer.%s", id) -} - // Config sets configuration options for each attach point. type Config struct { // ROMount is set to true if this is a readonly mount. @@ -132,19 +102,19 @@ func (a *attachPoint) Attach() (p9.File, error) { return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix) } - f, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { + f, readable, err := openAnyFile(a.prefix, func(mode int) (*fd.FD, error) { return fd.Open(a.prefix, openFlags|mode, 0) }) if err != nil { return nil, fmt.Errorf("unable to open %q: %v", a.prefix, err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return nil, fmt.Errorf("unable to stat %q: %v", a.prefix, err) } - lf, err := newLocalFile(a, f, a.prefix, stat) + lf, err := newLocalFile(a, f, a.prefix, readable, stat) if err != nil { return nil, fmt.Errorf("unable to create localFile %q: %v", a.prefix, err) } @@ -199,8 +169,6 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // entire file up when it's opened in write mode, and would perform badly when // multiple files are only being opened for read (esp. startup). type localFile struct { - p9.DefaultWalkGetAttr - // attachPoint is the attachPoint that serves this localFile. attachPoint *attachPoint @@ -212,12 +180,19 @@ type localFile struct { // opened with. file *fd.FD + // controlReadable tells whether 'file' was opened with read permissions + // during a walk. + controlReadable bool + // mode is the mode in which the file was opened. Set to invalidMode // if localFile isn't opened. mode p9.OpenFlags - // ft is the fileType for this file. - ft fileType + // fileType for this file. It is equivalent to: + // syscall.Stat_t.Mode & syscall.S_IFMT + fileType uint32 + + qid p9.QID // readDirMu protects against concurrent Readdir calls. readDirMu sync.Mutex @@ -251,83 +226,88 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { return fd.New(d), nil } -func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, error) { +func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) { path := path.Join(parent.hostPath, name) - f, err := openAnyFile(path, func(mode int) (*fd.FD, error) { + f, readable, err := openAnyFile(path, func(mode int) (*fd.FD, error) { return fd.OpenAt(parent.file, name, openFlags|mode, 0) }) - return f, path, err + return f, path, readable, err } // openAnyFile attempts to open the file in O_RDONLY and if it fails fallsback // to O_PATH. 'path' is used for logging messages only. 'fn' is what does the // actual file open and is customizable by the caller. -func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error) { +func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) { // Attempt to open file in the following mode in order: // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. // Use non-blocking to prevent getting stuck inside open(2) for // FIFOs. This option has no effect on regular files. // 2. PATH: for symlinks, sockets. - modes := []int{syscall.O_RDONLY | syscall.O_NONBLOCK, unix.O_PATH} + options := []struct { + mode int + readable bool + }{ + { + mode: syscall.O_RDONLY | syscall.O_NONBLOCK, + readable: true, + }, + { + mode: unix.O_PATH, + readable: false, + }, + } var err error - var file *fd.FD - for i, mode := range modes { - file, err = fn(mode) + for i, option := range options { + var file *fd.FD + file, err = fn(option.mode) if err == nil { - // openat succeeded, we're done. - break + // Succeeded opening the file, we're done. + return file, option.readable, nil } switch e := extractErrno(err); e { case syscall.ENOENT: // File doesn't exist, no point in retrying. - return nil, e + return nil, false, e } - // openat failed. Try again with next mode, preserving 'err' in case this - // was the last attempt. - log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|mode, path, err) + // File failed to open. Try again with next mode, preserving 'err' in case + // this was the last attempt. + log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, path, err) } - if err != nil { - // All attempts to open file have failed, return the last error. - log.Debugf("Failed to open file, path: %q, err: %v", path, err) - return nil, extractErrno(err) - } - - return file, nil + // All attempts to open file have failed, return the last error. + log.Debugf("Failed to open file, path: %q, err: %v", path, err) + return nil, false, extractErrno(err) } -func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) { - var ft fileType +func checkSupportedFileType(stat syscall.Stat_t, permitSocket bool) error { switch stat.Mode & syscall.S_IFMT { - case syscall.S_IFREG: - ft = regular - case syscall.S_IFDIR: - ft = directory - case syscall.S_IFLNK: - ft = symlink + case syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK: + return nil + case syscall.S_IFSOCK: if !permitSocket { - return unknown, syscall.EPERM + return syscall.EPERM } - ft = socket + return nil + default: - return unknown, syscall.EPERM + return syscall.EPERM } - return ft, nil } -func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) { - ft, err := getSupportedFileType(stat, a.conf.HostUDS) - if err != nil { +func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat syscall.Stat_t) (*localFile, error) { + if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil { return nil, err } return &localFile{ - attachPoint: a, - hostPath: path, - file: file, - mode: invalidMode, - ft: ft, + attachPoint: a, + hostPath: path, + file: file, + mode: invalidMode, + fileType: stat.Mode & syscall.S_IFMT, + qid: a.makeQID(stat), + controlReadable: readable, }, nil } @@ -346,13 +326,13 @@ func newFDMaybe(file *fd.FD) *fd.FD { // fd is blocking; non-blocking is required. if err := syscall.SetNonblock(dup.FD(), true); err != nil { - dup.Close() + _ = dup.Close() return nil } return dup } -func stat(fd int) (syscall.Stat_t, error) { +func fstat(fd int) (syscall.Stat_t, error) { var stat syscall.Stat_t if err := syscall.Fstat(fd, &stat); err != nil { return syscall.Stat_t{}, err @@ -360,6 +340,14 @@ func stat(fd int) (syscall.Stat_t, error) { return stat, nil } +func stat(path string) (syscall.Stat_t, error) { + var stat syscall.Stat_t + if err := syscall.Stat(path, &stat); err != nil { + return syscall.Stat_t{}, err + } + return stat, nil +} + func fchown(fd int, uid p9.UID, gid p9.GID) error { return syscall.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) } @@ -372,7 +360,7 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if flags == p9.ReadOnly { + if flags == p9.ReadOnly && l.controlReadable { log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { @@ -388,16 +376,8 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } } - stat, err := stat(newFile.FD()) - if err != nil { - if newFile != l.file { - newFile.Close() - } - return nil, p9.QID{}, 0, extractErrno(err) - } - var fd *fd.FD - if stat.Mode&syscall.S_IFMT == syscall.S_IFREG { + if l.fileType == syscall.S_IFREG { // Donate FD for regular files only. fd = newFDMaybe(newFile) } @@ -410,7 +390,7 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { l.file = newFile } l.mode = flags & p9.OpenFlagsModeMask - return fd, l.attachPoint.makeQID(stat), 0, nil + return fd, l.qid, 0, nil } // Create implements p9.File. @@ -438,7 +418,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid return nil, nil, p9.QID{}, 0, extractErrno(err) } cu := cleanup.Make(func() { - child.Close() + _ = child.Close() // Best effort attempt to remove the file in case of failure. if err := syscall.Unlinkat(l.file.FD(), name); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) @@ -449,7 +429,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid if err := fchown(child.FD(), uid, gid); err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } - stat, err := stat(child.FD()) + stat, err := fstat(child.FD()) if err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } @@ -459,10 +439,12 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid hostPath: path.Join(l.hostPath, name), file: child, mode: mode, + fileType: syscall.S_IFREG, + qid: l.attachPoint.makeQID(stat), } cu.Release() - return newFDMaybe(c.file), c, l.attachPoint.makeQID(stat), 0, nil + return newFDMaybe(c.file), c, c.qid, 0, nil } // Mkdir implements p9.File. @@ -497,7 +479,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) if err := fchown(f.FD(), uid, gid); err != nil { return p9.QID{}, extractErrno(err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return p9.QID{}, extractErrno(err) } @@ -508,55 +490,74 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) // Walk implements p9.File. func (l *localFile) Walk(names []string) ([]p9.QID, p9.File, error) { + qids, file, _, err := l.walk(names) + return qids, file, err +} + +// WalkGetAttr implements p9.File. +func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, p9.Attr, error) { + qids, file, stat, err := l.walk(names) + if err != nil { + return nil, nil, p9.AttrMask{}, p9.Attr{}, err + } + mask, attr := l.fillAttr(stat) + return qids, file, mask, attr, nil +} + +func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, error) { // Duplicate current file if 'names' is empty. if len(names) == 0 { - newFile, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { + newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { return reopenProcFd(l.file, openFlags|mode) }) if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - stat, err := stat(newFile.FD()) + stat, err := fstat(newFile.FD()) if err != nil { - newFile.Close() - return nil, nil, extractErrno(err) + _ = newFile.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } c := &localFile{ - attachPoint: l.attachPoint, - hostPath: l.hostPath, - file: newFile, - mode: invalidMode, + attachPoint: l.attachPoint, + hostPath: l.hostPath, + file: newFile, + mode: invalidMode, + fileType: l.fileType, + qid: l.attachPoint.makeQID(stat), + controlReadable: readable, } - return []p9.QID{l.attachPoint.makeQID(stat)}, c, nil + return []p9.QID{c.qid}, c, stat, nil } var qids []p9.QID + var lastStat syscall.Stat_t last := l for _, name := range names { - f, path, err := openAnyFileFromParent(last, name) + f, path, readable, err := openAnyFileFromParent(last, name) if last != l { - last.Close() + _ = last.Close() } if err != nil { - return nil, nil, extractErrno(err) + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - stat, err := stat(f.FD()) + lastStat, err = fstat(f.FD()) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - c, err := newLocalFile(last.attachPoint, f, path, stat) + c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat) if err != nil { - f.Close() - return nil, nil, extractErrno(err) + _ = f.Close() + return nil, nil, syscall.Stat_t{}, extractErrno(err) } - qids = append(qids, l.attachPoint.makeQID(stat)) + qids = append(qids, c.qid) last = c } - return qids, last, nil + return qids, last, lastStat, nil } // StatFS implements p9.File. @@ -592,11 +593,15 @@ func (l *localFile) FSync() error { // GetAttr implements p9.File. func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) { - stat, err := stat(l.file.FD()) + stat, err := fstat(l.file.FD()) if err != nil { return p9.QID{}, p9.AttrMask{}, p9.Attr{}, extractErrno(err) } + mask, attr := l.fillAttr(stat) + return l.qid, mask, attr, nil +} +func (l *localFile) fillAttr(stat syscall.Stat_t) (p9.AttrMask, p9.Attr) { attr := p9.Attr{ Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), @@ -625,8 +630,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) MTime: true, CTime: true, } - - return l.attachPoint.makeQID(stat), valid, attr, nil + return valid, attr } // SetAttr implements p9.File. Due to mismatch in file API, options @@ -667,7 +671,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // Check if it's possible to use cached file, or if another one needs to be // opened for write. f := l.file - if l.ft == regular && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { + if l.fileType == syscall.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { var err error f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY) if err != nil { @@ -723,7 +727,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } } - if l.ft == symlink { + if l.fileType == syscall.S_IFLNK { // utimensat operates different that other syscalls. To operate on a // symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty // name. @@ -880,7 +884,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. if err := fchown(f.FD(), uid, gid); err != nil { return p9.QID{}, extractErrno(err) } - stat, err := stat(f.FD()) + stat, err := fstat(f.FD()) if err != nil { return p9.QID{}, extractErrno(err) } @@ -907,13 +911,39 @@ func (l *localFile) Link(target p9.File, newName string) error { } // Mknod implements p9.File. -// -// Not implemented. -func (*localFile) Mknod(_ string, _ p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { +func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { + conf := l.attachPoint.conf + if conf.ROMount { + if conf.PanicOnWrite { + panic("attempt to write to RO mount") + } + return p9.QID{}, syscall.EROFS + } + + hostPath := path.Join(l.hostPath, name) + + // Return EEXIST if the file already exists. + if _, err := stat(hostPath); err == nil { + return p9.QID{}, syscall.EEXIST + } + // From mknod(2) man page: // "EPERM: [...] if the filesystem containing pathname does not support // the type of node requested." - return p9.QID{}, syscall.EPERM + if mode.FileType() != p9.ModeRegular { + return p9.QID{}, syscall.EPERM + } + + // Allow Mknod to create regular files. + if err := syscall.Mknod(hostPath, uint32(mode), 0); err != nil { + return p9.QID{}, err + } + + stat, err := stat(hostPath) + if err != nil { + return p9.QID{}, extractErrno(err) + } + return l.attachPoint.makeQID(stat), nil } // UnlinkAt implements p9.File. @@ -949,9 +979,12 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { skip := uint64(0) - // Check if the file is at the correct position already. If not, seek to the - // beginning and read the entire directory again. - if l.lastDirentOffset != offset { + // Check if the file is at the correct position already. If not, seek to + // the beginning and read the entire directory again. We always seek if + // offset is 0, since this is side-effectual (equivalent to rewinddir(3), + // which causes the directory stream to resynchronize with the directory's + // current contents). + if l.lastDirentOffset != offset || offset == 0 { if _, err := syscall.Seek(l.file.FD(), 0, 0); err != nil { return nil, extractErrno(err) } @@ -1079,13 +1112,13 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { } if err := syscall.SetNonblock(f, true); err != nil { - syscall.Close(f) + _ = syscall.Close(f) return nil, err } sa := syscall.SockaddrUnix{Name: l.hostPath} if err := syscall.Connect(f, &sa); err != nil { - syscall.Close(f) + _ = syscall.Close(f) return nil, err } diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 05af7e397..94f167417 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -26,6 +26,19 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} + +var ( + allTypes = []uint32{syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK} + + // allConfs is set in init(). + allConfs []Config + + rwConfs = []Config{{ROMount: false}} + roConfs = []Config{{ROMount: true}} ) func init() { @@ -39,6 +52,13 @@ func init() { } } +func configTestName(config *Config) string { + if config.ROMount { + return "ROMount" + } + return "RWMount" +} + func assertPanic(t *testing.T, f func()) { defer func() { if r := recover(); r == nil { @@ -88,71 +108,76 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { return nil } -var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} - -var ( - allTypes = []fileType{regular, directory, symlink} - - // allConfs is set in init() above. - allConfs []Config - - rwConfs = []Config{{ROMount: false}} - roConfs = []Config{{ROMount: true}} -) - type state struct { - root *localFile - file *localFile - conf Config - ft fileType + root *localFile + file *localFile + conf Config + fileType uint32 } func (s state) String() string { - return fmt.Sprintf("type(%v)", s.ft) + return fmt.Sprintf("type(%v)", s.fileType) +} + +func typeName(fileType uint32) string { + switch fileType { + case syscall.S_IFREG: + return "file" + case syscall.S_IFDIR: + return "directory" + case syscall.S_IFLNK: + return "symlink" + default: + panic(fmt.Sprintf("invalid file type for test: %d", fileType)) + } } func runAll(t *testing.T, test func(*testing.T, state)) { runCustom(t, allTypes, allConfs, test) } -func runCustom(t *testing.T, types []fileType, confs []Config, test func(*testing.T, state)) { +func runCustom(t *testing.T, types []uint32, confs []Config, test func(*testing.T, state)) { for _, c := range confs { - t.Logf("Config: %+v", c) - for _, ft := range types { - t.Logf("File type: %v", ft) + name := fmt.Sprintf("%s/%s", configTestName(&c), typeName(ft)) + t.Run(name, func(t *testing.T) { + path, name, err := setup(ft) + if err != nil { + t.Fatalf("%v", err) + } + defer os.RemoveAll(path) - path, name, err := setup(ft) - if err != nil { - t.Fatalf("%v", err) - } - defer os.RemoveAll(path) + a, err := NewAttachPoint(path, c) + if err != nil { + t.Fatalf("NewAttachPoint failed: %v", err) + } + root, err := a.Attach() + if err != nil { + t.Fatalf("Attach failed, err: %v", err) + } - a, err := NewAttachPoint(path, c) - if err != nil { - t.Fatalf("NewAttachPoint failed: %v", err) - } - root, err := a.Attach() - if err != nil { - t.Fatalf("Attach failed, err: %v", err) - } + _, file, err := root.Walk([]string{name}) + if err != nil { + root.Close() + t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) + } - _, file, err := root.Walk([]string{name}) - if err != nil { + st := state{ + root: root.(*localFile), + file: file.(*localFile), + conf: c, + fileType: ft, + } + test(t, st) + file.Close() root.Close() - t.Fatalf("root.Walk({%q}) failed, err: %v", "symlink", err) - } - - st := state{root: root.(*localFile), file: file.(*localFile), conf: c, ft: ft} - test(t, st) - file.Close() - root.Close() + }) } } } -func setup(ft fileType) (string, string, error) { - path, err := ioutil.TempDir("", "root-") +func setup(fileType uint32) (string, string, error) { + path, err := ioutil.TempDir(testutil.TmpDir(), "root-") if err != nil { return "", "", fmt.Errorf("ioutil.TempDir() failed, err: %v", err) } @@ -169,26 +194,26 @@ func setup(ft fileType) (string, string, error) { defer root.Close() var name string - switch ft { - case regular: + switch fileType { + case syscall.S_IFREG: name = "file" _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) } defer f.Close() - case directory: + case syscall.S_IFDIR: name = "dir" if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err) } - case symlink: + case syscall.S_IFLNK: name = "symlink" if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err) } default: - panic(fmt.Sprintf("unknown file type %v", ft)) + panic(fmt.Sprintf("unknown file type %v", fileType)) } return path, name, nil } @@ -202,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) { } func TestReadWrite(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -232,7 +257,7 @@ func TestReadWrite(t *testing.T) { } func TestCreate(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { for i, flags := range allOpenFlags { _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { @@ -249,7 +274,7 @@ func TestCreate(t *testing.T) { // TestReadWriteDup tests that a file opened in any mode can be dup'ed and // reopened in any other mode. func TestReadWriteDup(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -291,7 +316,7 @@ func TestReadWriteDup(t *testing.T) { } func TestUnopened(t *testing.T) { - runCustom(t, []fileType{regular}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFREG}, allConfs, func(t *testing.T, s state) { b := []byte("foobar") if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF { t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err) @@ -308,6 +333,32 @@ func TestUnopened(t *testing.T) { }) } +// TestOpenOPath is a regression test to ensure that a file that cannot be open +// for read is allowed to be open. This was happening because the control file +// was open with O_PATH, but Open() was not checking for it and allowing the +// control file to be reused. +func TestOpenOPath(t *testing.T) { + runCustom(t, []uint32{syscall.S_IFREG}, rwConfs, func(t *testing.T, s state) { + // Fist remove all permissions on the file. + if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil { + t.Fatalf("SetAttr(): %v", err) + } + // Then walk to the file again to open a new control file. + filename := filepath.Base(s.file.hostPath) + _, newFile, err := s.root.Walk([]string{filename}) + if err != nil { + t.Fatalf("root.Walk(%q): %v", filename, err) + } + + if newFile.(*localFile).controlReadable { + t.Fatalf("control file didn't open with O_PATH: %+v", newFile) + } + if _, _, _, err := newFile.Open(p9.ReadOnly); err != syscall.EACCES { + t.Fatalf("Open() should have failed, got: %v, wanted: EACCES", err) + } + }) +} + func SetGetAttr(l *localFile, valid p9.SetAttrMask, attr p9.SetAttr) (p9.Attr, error) { if err := l.SetAttr(valid, attr); err != nil { return p9.Attr{}, err @@ -324,7 +375,7 @@ func TestSetAttrPerm(t *testing.T) { valid := p9.SetAttrMask{Permissions: true} attr := p9.SetAttr{Permissions: 0777} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink { + if s.fileType == syscall.S_IFLNK { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -345,7 +396,7 @@ func TestSetAttrSize(t *testing.T) { valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: size} got, err := SetGetAttr(s.file, valid, attr) - if s.ft == symlink || s.ft == directory { + if s.fileType == syscall.S_IFLNK || s.fileType == syscall.S_IFDIR { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -427,7 +478,7 @@ func TestLink(t *testing.T) { } err = dir.Link(s.file, linkFile) - if s.ft == directory { + if s.fileType == syscall.S_IFDIR { if err != syscall.EPERM { t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err) } @@ -485,7 +536,7 @@ func TestROMountPanics(t *testing.T) { } func TestWalkNotFound(t *testing.T) { - runCustom(t, []fileType{directory}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, allConfs, func(t *testing.T, s state) { if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT { t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err) } @@ -506,7 +557,7 @@ func TestWalkDup(t *testing.T) { } func TestReaddir(t *testing.T) { - runCustom(t, []fileType{directory}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { name := "dir" if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err) diff --git a/runsc/main.go b/runsc/main.go index c9f47c579..69cb505fa 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -88,6 +88,7 @@ var ( referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") + fuseEnabled = flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") // Test flags, not to be used outside tests, ever. testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") @@ -242,6 +243,7 @@ func main() { OverlayfsStaleRead: *overlayfsStaleRead, CPUNumFromQuota: *cpuNumFromQuota, VFS2: *vfs2Enabled, + FUSE: *fuseEnabled, QDisc: queueingDiscipline, TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, TestOnlyTestNameEnv: *testOnlyTestNameEnv, diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 035dcd3e3..2b9d4549d 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -29,7 +29,7 @@ go_library( "//runsc/console", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@com_github_vishvananda_netlink//:go_default_library", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index deee619f3..817a923ad 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -134,7 +134,6 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG return err } if isRoot { - return fmt.Errorf("cannot run with network enabled in root network namespace") } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 6e1a2af25..2afcc27af 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -478,9 +478,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // If the console control socket file is provided, then create a new // pty master/slave pair and set the TTY on the sandbox process. - if args.ConsoleSocket != "" { - cmd.Args = append(cmd.Args, "--console=true") - + if args.Spec.Process.Terminal && args.ConsoleSocket != "" { // console.NewWithSocket will send the master on the given // socket, and return the slave. tty, err := console.NewWithSocket(args.ConsoleSocket) diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 62d4f5113..43851a22f 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -18,7 +18,7 @@ go_library( "//pkg/sentry/kernel/auth", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_mohae_deepcopy//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -29,5 +29,5 @@ go_test( size = "small", srcs = ["specutils_test.go"], library = ":specutils", - deps = ["@com_github_opencontainers_runtime-spec//specs-go:go_default_library"], + deps = ["@com_github_opencontainers_runtime_spec//specs-go:go_default_library"], ) diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh deleted file mode 100755 index e0f6df438..000000000 --- a/scripts/benchmark.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -source $(dirname $0)/common.sh - -# gcloud may be installed as a "snap". If it is, include it in PATH. -declare -r snap="/snap/bin" -if [[ -d "${snap}" ]]; then - export PATH="${PATH}:${snap}" -fi - -# Make sure we can find gcloud and exit if not. -which gcloud - -# Exporting for subprocesses as GCP APIs and tools check this environmental -# variable for authentication. -export GOOGLE_APPLICATION_CREDENTIALS="${KOKORO_KEYSTORE_DIR}/${GCLOUD_CREDENTIALS}" - -gcloud auth activate-service-account \ - --key-file "${GOOGLE_APPLICATION_CREDENTIALS}" - -gcloud config set project ${PROJECT} -gcloud config set compute/zone ${ZONE} - -bazel run //benchmarks:benchmarks -- \ - --verbose \ - run-gcp \ - "(startup|absl)" \ - --internal \ - --runtime=runc \ - --runtime=runsc \ - --installers=head diff --git a/scripts/common_build.sh b/scripts/common_build.sh index 0d9a191b5..d4a6c4908 100755 --- a/scripts/common_build.sh +++ b/scripts/common_build.sh @@ -31,9 +31,9 @@ declare -a BAZEL_FLAGS=( "--keep_going" "--verbose_failures=true" ) -if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then +# If running via kokoro, use the remote config. +if [[ -v KOKORO_ARTIFACTS_DIR ]]; then BAZEL_FLAGS+=( - "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}" "--config=remote" ) fi @@ -64,7 +64,7 @@ function run_as_root() { } function query() { - QUERY_RESULT=$(bazel query "$@") + bazel query "$@" } function collect_logs() { diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh index dce0a4085..07e9f3109 100755 --- a/scripts/docker_tests.sh +++ b/scripts/docker_tests.sh @@ -22,4 +22,6 @@ install_runsc_for_test docker test_runsc //test/image:image_test //test/e2e:integration_test install_runsc_for_test docker --vfs2 -test_runsc //test/image:image_test --test_filter=.*TestHelloWorld +IMAGE_FILTER="Hello|Httpd|Ruby|Stdio" +INTEGRATION_FILTER="LifeCycle|Pause|Connect|JobControl|Overlay|Exec|DirCreation/root" +test_runsc //test/e2e:integration_test //test/image:image_test --test_filter="${IMAGE_FILTER}|${INTEGRATION_FILTER}" diff --git a/scripts/issue_reviver.sh b/scripts/issue_reviver.sh deleted file mode 100755 index bac9b9192..000000000 --- a/scripts/issue_reviver.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -DIR=$(dirname $0) -source "${DIR}"/common.sh - -# Provide a credential file if available. -export OAUTH_TOKEN_FILE="" -if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then - OAUTH_TOKEN_FILE="${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}" -fi - -REPO_ROOT=$(cd "$(dirname "${DIR}")"; pwd) -run //tools/issue_reviver:issue_reviver --path "${REPO_ROOT}" --oauth-token-file="${OAUTH_TOKEN_FILE}" diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh index 727503bce..1a8181ac8 100755 --- a/scripts/packetdrill_tests.sh +++ b/scripts/packetdrill_tests.sh @@ -19,5 +19,5 @@ source $(dirname $0)/common.sh make load-packetdrill install_runsc_for_test runsc-d -query "attr(tags, manual, tests(//test/packetdrill/...))" +QUERY_RESULT=$(query "attr(tags, manual, tests(//test/packetdrill/...))") test_runsc $QUERY_RESULT diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh index 51c11f23f..77fb84bc3 100755 --- a/scripts/packetimpact_tests.sh +++ b/scripts/packetimpact_tests.sh @@ -19,5 +19,5 @@ source $(dirname $0)/common.sh make load-packetimpact install_runsc_for_test runsc-d -query "attr(tags, packetimpact, tests(//test/packetimpact/...))" +QUERY_RESULT=$(query "attr(tags, packetimpact, tests(//test/packetimpact/...))") test_runsc $QUERY_RESULT diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh index d629bf2aa..3eb735e62 100755 --- a/scripts/root_tests.sh +++ b/scripts/root_tests.sh @@ -17,15 +17,8 @@ source $(dirname $0)/common.sh make load-all-images - -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim +CONTAINERD_VERSION=1.3.4 make sudo TARGETS="tools/installers:containerd" +make sudo TARGETS="tools/installers:shim" # Run the tests that require root. install_runsc_for_test root diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh index 350a59f7c..85e95d45d 100755 --- a/scripts/runtime_tests.sh +++ b/scripts/runtime_tests.sh @@ -22,5 +22,8 @@ if [ ! -v RUNTIME_TEST_NAME ]; then exit 1 fi +# Download language runtime image. +make -C images/ "load-runtimes_${RUNTIME_TEST_NAME}" + install_runsc_for_test runtimes -test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test" +test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}" diff --git a/shim/BUILD b/shim/BUILD new file mode 100644 index 000000000..e581618b2 --- /dev/null +++ b/shim/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "pkg_tar") + +package(licenses = ["notice"]) + +pkg_tar( + name = "config", + srcs = [ + "runsc.toml", + ], + mode = "0644", + package_dir = "/etc/containerd", + visibility = [ + "//runsc:__pkg__", + ], +) diff --git a/shim/README.md b/shim/README.md new file mode 100644 index 000000000..c6824ebdc --- /dev/null +++ b/shim/README.md @@ -0,0 +1,11 @@ +# gVisor Containerd shims + +There are various shims supported for differt versions of +[containerd][containerd]. + +- [Configure gvisor-containerd-shim (shim v1) (containerd ≤ 1.2)](v1/configure-gvisor-containerd-shim.md) +- [Runtime Handler/RuntimeClass Quick Start (containerd >= 1.2)](v2/runtime-handler-quickstart.md) +- [Runtime Handler/RuntimeClass Quick Start (shim v2) (containerd >= 1.2)](v2/runtime-handler-shim-v2-quickstart.md) +- [Configure containerd-shim-runsc-v1 (shim v2) (containerd >= 1.3)](v2/configure-containerd-shim-runsc-v1.md) + +[containerd]: https://github.com/containerd/containerd diff --git a/shim/runsc.toml b/shim/runsc.toml new file mode 100644 index 000000000..e1c7de1bb --- /dev/null +++ b/shim/runsc.toml @@ -0,0 +1,6 @@ +# This is an example configuration file for runsc. +# +# By default, it will be parsed from /etc/containerd/runsc.toml, but see the +# static path configured in v1/main.go. Note that the configuration mechanism +# for newer container shim versions is different: see the documentation in v2. +[runsc_config] diff --git a/shim/v1/BUILD b/shim/v1/BUILD new file mode 100644 index 000000000..7b837630c --- /dev/null +++ b/shim/v1/BUILD @@ -0,0 +1,41 @@ +load("//tools:defs.bzl", "go_binary") +load("//website:defs.bzl", "doc") + +package(licenses = ["notice"]) + +go_binary( + name = "gvisor-containerd-shim", + srcs = [ + "api.go", + "config.go", + "main.go", + ], + static = True, + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/shim", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_ttrpc//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +doc( + name = "doc", + src = "README.md", + category = "User Guide", + permalink = "/docs/user_guide/gvisor-containerd-shim/", + subcategory = "Advanced", + visibility = ["//website:__pkg__"], + weight = "93", +) diff --git a/shim/v1/README.md b/shim/v1/README.md new file mode 100644 index 000000000..7aa4513a1 --- /dev/null +++ b/shim/v1/README.md @@ -0,0 +1,50 @@ +# gvisor-containerd-shim + +> Note: This shim version is supported only for containerd versions less than +> 1.2. If you are using a containerd version greater than or equal to 1.2, then +> please use `containerd-shim-runsc-v1` (Shim API v1). +> +> This containerd shim is supported only in a best-effort capacity. + +This document describes how to configure and use `gvisor-containerd-shim`. + +## Containerd Configuration + +To use this shim, you must configure `/etc/containerd/config.toml` as follows: + +``` +[plugins.linux] + shim = "/usr/bin/gvisor-containerd-shim" +[plugins.cri.containerd.runtimes.gvisor] + runtime_type = "io.containerd.runtime.v1.linux" + runtime_engine = "/usr/bin/runsc" + runtime_root = "/run/containerd/runsc" +``` + +In order to pick-up the new configuration, you may need to restart containerd: + +```shell +sudo systemctl restart containerd +``` + +## Shim Confguration + +The shim configuration is stored in `/etc/containerd/runsc.toml`. The +configuration file supports two values. + +* `runc_shim`: The path to the runc shim. This is used by + `gvisor-containerd-shim` to run standard containers. + +* `runsc_config`: This is a set of key/value pairs that are converted into + `runsc` command line flags. You can learn more about which flags are + available by running `runsc flags`. + +For example, a configuration might look as follows: + +``` +runc_shim = "/usr/local/bin/containerd-shim" +[runsc_config] +platform = "kvm" +debug = true +debug-log = /var/log/%ID%/gvisor/ +``` diff --git a/shim/v1/api.go b/shim/v1/api.go new file mode 100644 index 000000000..2444d23f1 --- /dev/null +++ b/shim/v1/api.go @@ -0,0 +1,24 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + shim "github.com/containerd/containerd/runtime/v1/shim/v1" +) + +type KillRequest = shim.KillRequest + +var registerShimService = shim.RegisterShimService diff --git a/shim/v1/config.go b/shim/v1/config.go new file mode 100644 index 000000000..a72cc7754 --- /dev/null +++ b/shim/v1/config.go @@ -0,0 +1,40 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "github.com/BurntSushi/toml" + +// config is the configuration for gvisor containerd shim. +type config struct { + // RuncShim is the shim binary path for standard containerd-shim for runc. + // When the runtime is `runc`, gvisor containerd shim will exec current + // process to standard containerd-shim. This is a work around for containerd + // 1.1. In containerd 1.2, containerd will choose different containerd-shims + // based on runtime. + RuncShim string `toml:"runc_shim"` + // RunscConfig is configuration for runsc. The key value will be converted + // to runsc flags --key=value directly. + RunscConfig map[string]string `toml:"runsc_config"` +} + +// loadConfig load gvisor containerd shim config from config file. +func loadConfig(path string) (*config, error) { + var c config + _, err := toml.DecodeFile(path, &c) + if err != nil { + return &c, err + } + return &c, nil +} diff --git a/shim/v1/main.go b/shim/v1/main.go new file mode 100644 index 000000000..3159923af --- /dev/null +++ b/shim/v1/main.go @@ -0,0 +1,265 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/sys" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/ttrpc" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/shim" +) + +var ( + debugFlag bool + namespaceFlag string + socketFlag string + addressFlag string + workdirFlag string + runtimeRootFlag string + containerdBinaryFlag string + shimConfigFlag string +) + +// Containerd defaults to runc, unless another runtime is explicitly specified. +// We keep the same default to make the default behavior consistent. +const defaultRoot = "/run/containerd/runc" + +func init() { + flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") + flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") + flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") + flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") + flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime") + + // Currently, the `containerd publish` utility is embedded in the + // daemon binary. The daemon invokes `containerd-shim + // -containerd-binary ...` with its own os.Executable() path. + flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)") + flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file") +} + +func main() { + flag.Parse() + + // This is a hack. Exec current process to run standard containerd-shim + // if runtime root is not `runsc`. We don't need this for shim v2 api. + if filepath.Base(runtimeRootFlag) != "runsc" { + if err := executeRuncShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } + } + + // Run regular shim if needed. + if err := executeShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } +} + +// executeRuncShim execs current process to a containerd-shim process and +// retains all flags and envs. +func executeRuncShim() error { + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + shimPath := c.RuncShim + if shimPath == "" { + shimPath, err = exec.LookPath("containerd-shim") + if err != nil { + return fmt.Errorf("lookup containerd-shim failed: %w", err) + } + } + + args := append([]string{shimPath}, os.Args[1:]...) + if err := syscall.Exec(shimPath, args, os.Environ()); err != nil { + return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err) + } + return nil +} + +func executeShim() error { + // start handling signals as soon as possible so that things are + // properly reaped or if runtime exits before we hit the handler. + signals, err := setupSignals() + if err != nil { + return err + } + path, err := os.Getwd() + if err != nil { + return err + } + server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) + if err != nil { + return fmt.Errorf("failed creating server: %w", err) + } + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + sv, err := shim.NewService( + shim.Config{ + Path: path, + Namespace: namespaceFlag, + WorkDir: workdirFlag, + RuntimeRoot: runtimeRootFlag, + RunscConfig: c.RunscConfig, + }, + &remoteEventsPublisher{address: addressFlag}, + ) + if err != nil { + return err + } + registerShimService(server, sv) + if err := serve(server, socketFlag); err != nil { + return err + } + return handleSignals(signals, server, sv) +} + +// serve serves the ttrpc API over a unix socket at the provided path this +// function does not block. +func serve(server *ttrpc.Server, path string) error { + var ( + l net.Listener + err error + ) + if path == "" { + l, err = net.FileListener(os.NewFile(3, "socket")) + path = "[inherited from parent]" + } else { + if len(path) > 106 { + return fmt.Errorf("%q: unix socket path too long (> 106)", path) + } + l, err = net.Listen("unix", "\x00"+path) + } + if err != nil { + return err + } + go func() { + defer l.Close() + err := server.Serve(context.Background(), l) + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + log.Fatalf("ttrpc server failure: %v", err) + } + }() + return nil +} + +// setupSignals creates a new signal handler for all signals and sets the shim +// as a sub-reaper so that the container processes are reparented. +func setupSignals() (chan os.Signal, error) { + signals := make(chan os.Signal, 32) + signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE) + // make sure runc is setup to use the monitor for waiting on processes. + // TODO(random-liu): Move shim/reaper.go to a separate package. + runsc.Monitor = reaper.Default + // Set the shim as the subreaper for all orphaned processes created by + // the container. + if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { + return nil, err + } + return signals, nil +} + +func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error { + var ( + termOnce sync.Once + done = make(chan struct{}) + ) + + for { + select { + case <-done: + return nil + case s := <-signals: + switch s { + case unix.SIGCHLD: + if _, err := sys.Reap(false); err != nil { + log.Printf("reap error: %v", err) + } + case unix.SIGTERM, unix.SIGINT: + go termOnce.Do(func() { + ctx := context.TODO() + if err := server.Shutdown(ctx); err != nil { + log.Printf("failed to shutdown server: %v", err) + } + // Ensure our child is dead if any. + sv.Kill(ctx, &KillRequest{ + Signal: uint32(syscall.SIGKILL), + All: true, + }) + sv.Delete(context.Background(), &types.Empty{}) + close(done) + }) + case unix.SIGPIPE: + } + } + } +} + +type remoteEventsPublisher struct { + address string +} + +func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { + ns, _ := namespaces.Namespace(ctx) + encoded, err := typeurl.MarshalAny(event) + if err != nil { + return err + } + data, err := encoded.Marshal() + if err != nil { + return err + } + cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns) + cmd.Stdin = bytes.NewReader(data) + c, err := reaper.Default.Start(cmd) + if err != nil { + return err + } + status, err := reaper.Default.Wait(cmd, c) + if err != nil { + return fmt.Errorf("failed to publish event: %w", err) + } + if status != 0 { + return fmt.Errorf("failed to publish event: status %d", status) + } + return nil +} diff --git a/shim/v2/BUILD b/shim/v2/BUILD new file mode 100644 index 000000000..ae4705935 --- /dev/null +++ b/shim/v2/BUILD @@ -0,0 +1,29 @@ +load("//tools:defs.bzl", "go_binary") +load("//website:defs.bzl", "doc") + +package(licenses = ["notice"]) + +go_binary( + name = "containerd-shim-runsc-v1", + srcs = [ + "main.go", + ], + static = True, + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/shim/v2", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + ], +) + +doc( + name = "doc", + src = "README.md", + category = "User Guide", + permalink = "/docs/user_guide/containerd-shim-runsc-v1/", + subcategory = "Advanced", + visibility = ["//website:__pkg__"], + weight = "92", +) diff --git a/shim/v2/README.md b/shim/v2/README.md new file mode 100644 index 000000000..2aa7c21e3 --- /dev/null +++ b/shim/v2/README.md @@ -0,0 +1,91 @@ +# containerd-shim-runsc-v1 + +> Note: This shim version is the recommended shim for containerd versions +> greater than or equal to 1.2. For older versions of containerd, use +> `gvisor-containerd-shim`. + +This document describes how to configure and use `containerd-shim-runsc-v1`. + +## Configuring Containerd 1.2 + +To configure containerd 1.2 to use this shim, add the runtime to +`/etc/containerd/config.toml` as follows: + +``` +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runsc.v1" + runtime_root = "/run/containerd/runsc" +[plugins.cri.containerd.runtimes.runsc.options] + TypeUrl = "io.containerd.runsc.v1.options" +``` + +The configuration will optionally loaded from a file named `config.toml` in the +`runtime_root` configured above. + +In order to pick up the new configuration, you may need to restart containerd: + +```shell +sudo systemctl restart containerd +``` + +## Configuring Containerd 1.3 and above + +To configure containerd 1.3 to use this shim, add the runtime to +`/etc/containerd/config.toml` as follows: + +``` +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runsc.v1" +[plugins.cri.containerd.runtimes.runsc.options] + TypeUrl = "io.containerd.runsc.v1.options" + ConfigPath = "/etc/containerd/runsc.toml" +``` + +The `ConfigPath` above will be used to provide a pointer to the configuration +file to be loaded. + +> Note that there will be configuration file loaded if `ConfigPath` is not set. + +In order to pick up the new configuration, you may need to restart containerd: + +```shell +sudo systemctl restart containerd +``` + +## Shim Confguration + +The shim configuration may carry the following options: + +* `shim_cgroup`: The cgroup to use for the shim itself. +* `io_uid`: The UID to use for pipes. +* `ui_gid`: The GID to use for pipes. +* `binary_name`: The runtime binary name (defaults to `runsc`). +* `root`: The root directory for the runtime. +* `runsc_config`: A dictionary of key-value pairs that will be passed to the + runtime as arguments. + +### Example: Enable the KVM platform + +gVisor enables the use of a number of platforms. This example shows how to +configure `containerd-shim-runsc-v1` to use gVisor with the KVM platform: + +```shell +cat <<EOF | sudo tee /etc/containerd/runsc.toml +[runsc_config] +platform = "kvm" +EOF +``` + +### Example: Enable gVisor debug logging + +gVisor debug logging can be enabled by setting the `debug` and `debug-log` flag. +The shim will replace "%ID%" with the container ID in the path of the +`debug-log` flag. + +```shell +cat <<EOF | sudo tee /etc/containerd/runsc.toml +[runsc_config] +debug = true +debug-log = /var/log/%ID%/gvisor.log +EOF +``` diff --git a/shim/v2/main.go b/shim/v2/main.go new file mode 100644 index 000000000..753871eea --- /dev/null +++ b/shim/v2/main.go @@ -0,0 +1,26 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "github.com/containerd/containerd/runtime/v2/shim" + + "gvisor.dev/gvisor/pkg/shim/v2" +) + +func main() { + shim.Run("io.containerd.runsc.v1", v2.New) +} diff --git a/shim/v2/runtime-handler-shim-v2-quickstart.md b/shim/v2/runtime-handler-shim-v2-quickstart.md new file mode 100644 index 000000000..3b88ca74b --- /dev/null +++ b/shim/v2/runtime-handler-shim-v2-quickstart.md @@ -0,0 +1,251 @@ +# Runtime Handler Quickstart (Shim V2) + +This document describes how to install and run `containerd-shim-runsc-v1` using +the containerd runtime handler support. This requires containerd 1.2 or later. + +## Requirements + +- **runsc**: See the [gVisor documentation](https://github.com/google/gvisor) + for information on how to install runsc. +- **containerd**: See the [containerd website](https://containerd.io/) for + information on how to install containerd. + +## Install + +### Install containerd-shim-runsc-v1 + +1. Build and install `containerd-shim-runsc-v1`. + +<!-- TODO: Use a release once we have one available. --> + +[embedmd]:# (../test/e2e/shim-install.sh shell /{ # Step 1\(dev\)/ /^}/) + +```shell +{ # Step 1(dev): Build and install gvisor-containerd-shim and containerd-shim-runsc-v1 + make + sudo make install +} +``` + +### Configure containerd + +1. Update `/etc/containerd/config.toml`. Make sure `containerd-shim-runsc-v1` + is in `${PATH}`. + +[embedmd]:# (../test/e2e/runtime-handler-shim-v2/install.sh shell /{ # Step 1/ /^}/) + +```shell +{ # Step 1: Create containerd config.toml +cat <<EOF | sudo tee /etc/containerd/config.toml +disabled_plugins = ["restart"] +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.runsc] + runtime_type = "io.containerd.runsc.v1" +EOF +} +``` + +1. Restart `containerd` + +```shell +sudo systemctl restart containerd +``` + +## Usage + +You can run containers in gVisor via containerd's CRI. + +### Install crictl + +1. Download and install the crictl binary: + +[embedmd]:# (../test/e2e/crictl-install.sh shell /{ # Step 1/ /^}/) + +```shell +{ # Step 1: Download crictl +wget https://github.com/kubernetes-sigs/cri-tools/releases/download/v1.13.0/crictl-v1.13.0-linux-amd64.tar.gz +tar xf crictl-v1.13.0-linux-amd64.tar.gz +sudo mv crictl /usr/local/bin +} +``` + +1. Write the crictl configuration file + +[embedmd]:# (../test/e2e/crictl-install.sh shell /{ # Step 2/ /^}/) + +```shell +{ # Step 2: Configure crictl +cat <<EOF | sudo tee /etc/crictl.yaml +runtime-endpoint: unix:///run/containerd/containerd.sock +EOF +} +``` + +### Create the nginx Sandbox in gVisor + +1. Pull the nginx image + +[embedmd]:# (../test/e2e/runtime-handler/usage.sh shell /{ # Step 1/ /^}/) + +```shell +{ # Step 1: Pull the nginx image +sudo crictl pull nginx +} +``` + +1. Create the sandbox creation request + +[embedmd]:# (../test/e2e/runtime-handler/usage.sh shell /{ # Step 2/ /^EOF\n}/) + +```shell +{ # Step 2: Create sandbox.json +cat <<EOF | tee sandbox.json +{ + "metadata": { + "name": "nginx-sandbox", + "namespace": "default", + "attempt": 1, + "uid": "hdishd83djaidwnduwk28bcsb" + }, + "linux": { + }, + "log_directory": "/tmp" +} +EOF +} +``` + +1. Create the pod in gVisor + +[embedmd]:# (../test/e2e/runtime-handler/usage.sh shell /{ # Step 3/ /^}/) + +```shell +{ # Step 3: Create the sandbox +SANDBOX_ID=$(sudo crictl runp --runtime runsc sandbox.json) +} +``` + +### Run the nginx Container in the Sandbox + +1. Create the nginx container creation request + +[embedmd]:# (../test/e2e/run-container.sh shell /{ # Step 1/ /^EOF\n}/) + +```shell +{ # Step 1: Create nginx container config +cat <<EOF | tee container.json +{ + "metadata": { + "name": "nginx" + }, + "image":{ + "image": "nginx" + }, + "log_path":"nginx.0.log", + "linux": { + } +} +EOF +} +``` + +1. Create the nginx container + +[embedmd]:# (../test/e2e/run-container.sh shell /{ # Step 2/ /^}/) + +```shell +{ # Step 2: Create nginx container +CONTAINER_ID=$(sudo crictl create ${SANDBOX_ID} container.json sandbox.json) +} +``` + +1. Start the nginx container + +[embedmd]:# (../test/e2e/run-container.sh shell /{ # Step 3/ /^}/) + +```shell +{ # Step 3: Start nginx container +sudo crictl start ${CONTAINER_ID} +} +``` + +### Validate the container + +1. Inspect the created pod + +[embedmd]:# (../test/e2e/validate.sh shell /{ # Step 1/ /^}/) + +```shell +{ # Step 1: Inspect the pod +sudo crictl inspectp ${SANDBOX_ID} +} +``` + +1. Inspect the nginx container + +[embedmd]:# (../test/e2e/validate.sh shell /{ # Step 2/ /^}/) + +```shell +{ # Step 2: Inspect the container +sudo crictl inspect ${CONTAINER_ID} +} +``` + +1. Verify that nginx is running in gVisor + +[embedmd]:# (../test/e2e/validate.sh shell /{ # Step 3/ /^}/) + +```shell +{ # Step 3: Check dmesg +sudo crictl exec ${CONTAINER_ID} dmesg | grep -i gvisor +} +``` + +### Set up the Kubernetes Runtime Class + +1. Install the Runtime Class for gVisor + +[embedmd]:# (../test/e2e/runtimeclass-install.sh shell /{ # Step 1/ /^}/) + +```shell +{ # Step 1: Install a RuntimeClass +cat <<EOF | kubectl apply -f - +apiVersion: node.k8s.io/v1beta1 +kind: RuntimeClass +metadata: + name: gvisor +handler: runsc +EOF +} +``` + +1. Create a Pod with the gVisor Runtime Class + +[embedmd]:# (../test/e2e/runtimeclass-install.sh shell /{ # Step 2/ /^}/) + +```shell +{ # Step 2: Create a pod +cat <<EOF | kubectl apply -f - +apiVersion: v1 +kind: Pod +metadata: + name: nginx-gvisor +spec: + runtimeClassName: gvisor + containers: + - name: nginx + image: nginx +EOF +} +``` + +1. Verify that the Pod is running + +[embedmd]:# (../test/e2e/runtimeclass-install.sh shell /{ # Step 3/ /^}/) + +```shell +{ # Step 3: Get the pod +kubectl get pod nginx-gvisor -o wide +} +``` diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md new file mode 100644 index 000000000..d1bbabf6f --- /dev/null +++ b/test/benchmarks/README.md @@ -0,0 +1,157 @@ +# Benchmark tools + +This package and subpackages are for running macro benchmarks on `runsc`. They +are meant to replace the previous //benchmarks benchmark-tools written in +python. + +Benchmarks are meant to look like regular golang benchmarks using the testing.B +library. + +## Setup + +To run benchmarks you will need: + +* Docker installed (17.09.0 or greater). + +The easiest way to setup runsc for running benchmarks is to use the make file. +From the root directory: + +* Download images: `make load-all-images` +* Install runsc suitable for benchmarking, which should probably not have + strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc + ARGS=--platform=kvm`. +* Restart docker: `sudo service docker restart` + +You should now have a runtime with the following options configured in +`/etc/docker/daemon.json` + +``` +"myrunsc": { + "path": "/tmp/myrunsc/runsc", + "runtimeArgs": [ + "--debug-log", + "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%", + "--platform=kvm" + ] + }, + +``` + +This runtime has been configured with a debugging off and strace logs off and is +using kvm for demonstration. + +## Running benchmarks + +Given the runtime above runtime `myrunsc`, run benchmarks with the following: + +``` +make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \ + -test.bench=." OPTIONS="-c opt +``` + +For example, to run only the Iperf tests: + +``` +make sudo TARGETS=//test/benchmarks/network:network_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt" +``` + +Benchmarks are run with root as some benchmarks require root privileges to do +things like drop caches. + +## Writing benchmarks + +Benchmarks consist of docker images as Dockerfiles and golang testing.B +benchmarks. + +### Dockerfiles: + +* Are stored at //images. +* New Dockerfiles go in an appropriately named directory at + `//images/benchmarks/my-cool-dockerfile`. +* Dockerfiles for benchmarks should: + * Use explicitly versioned packages. + * Not use ENV and CMD statements...it is easy to add these in the API. +* Note: A common pattern for getting access to a tmpfs mount is to copy files + there after container start. See: //test/benchmarks/build/bazel_test.go. You + can also make your own with `RunOpts.Mounts`. + +### testing.B packages + +In general, benchmarks should look like this: + +```golang + +var h harness.Harness + +func BenchmarkMyCoolOne(b *testing.B) { + machine, err := h.GetMachine() + // check err + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + + //Respect b.N. + for i := 0; i < b.N; i++ { + out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/my-cool-image", + Env: []string{"MY_VAR=awesome"}, + other options...see dockerutil + }, "sh", "-c", "echo MY_VAR") + //check err + b.StopTimer() + + // Do parsing and reporting outside of the timer. + number := parseMyMetric(out) + b.ReportMetric(number, "my-cool-custom-metric") + + b.StartTimer() + } +} + +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} +``` + +Some notes on the above: + +* The harness is initiated in the TestMain method and made global to test + module. The harness will handle any presetup that needs to happen with + flags, remote virtual machines (eventually), and other services. +* Respect `b.N` in that users of the benchmark may want to "run for an hour" + or something of the sort. +* Use the `b.ReportMetric()` method to report custom metrics. +* Set the timer if time is useful for reporting. There isn't a way to turn off + default metrics in testing.B (B/op, allocs/op, ns/op). +* Take a look at dockerutil at //pkg/test/dockerutil to see all methods + available from containers. The API is based on the "official" + [docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker). +* `harness.GetMachine()` marks how many machines this tests needs. If you have + a client and server and to mark them as multiple machines, call + `harness.GetMachine()` twice. + +## Profiling + +For profiling, the runtime is required to have the `--profile` flag enabled. +This flag loosens seccomp filters so that the runtime can write profile data to +disk. This configuration is not recommended for production. + +* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc + ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not + required, but are included for demonstration. +* Restart docker: `sudo service docker restart` + +To run and generate CPU profiles fs_test test run: + +``` +make sudo TARGETS=//test/benchmarks/fs:fs_test \ + ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt" +``` + +Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof` diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD new file mode 100644 index 000000000..5e33465cd --- /dev/null +++ b/test/benchmarks/database/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "database", + testonly = 1, + srcs = ["database.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "database_test", + size = "enormous", + srcs = [ + "redis_test.go", + ], + library = ":database", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go new file mode 100644 index 000000000..9eeb59f9a --- /dev/null +++ b/test/benchmarks/database/database.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package database holds benchmarks around database applications. +package database + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package database. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go new file mode 100644 index 000000000..6d39f4d66 --- /dev/null +++ b/test/benchmarks/database/redis_test.go @@ -0,0 +1,197 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// All possible operations from redis. Note: "ping" will +// run both PING_INLINE and PING_BUILD. +var operations []string = []string{ + "PING_INLINE", + "PING_BULK", + "SET", + "GET", + "INCR", + "LPUSH", + "RPUSH", + "LPOP", + "RPOP", + "SADD", + "HSET", + "SPOP", + "LRANGE_100", + "LRANGE_300", + "LRANGE_500", + "LRANGE_600", + "MSET", +} + +// BenchmarkRedis runs redis-benchmark against a redis instance and reports +// data in queries per second. Each is reported by named operation (e.g. LPUSH). +func BenchmarkRedis(b *testing.B) { + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Redis runs on port 6379 by default. + port := 6379 + ctx := context.Background() + + for _, operation := range operations { + b.Run(operation, func(b *testing.B) { + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // The redis docker container takes no arguments to run a redis server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + Ports: []int{port}, + }); err != nil { + b.Fatalf("failed to start redis server with: %v", err) + } + + if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + serverPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatal("failed to get IP from server: %v", err) + } + + if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { + b.Fatalf("failed to start redis with: %v", err) + } + + // runs redis benchmark -t operation for 100K requests against server. + cmd := strings.Split( + fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", operation, ip, serverPort), " ") + + // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case. + // Note that "ping" will run both PING_INLINE and PING_BULK. + if operation == "PING_BULK" { + cmd = strings.Split( + fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, serverPort), " ") + } + // Reset profiles and timer to begin the measurement. + server.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }, cmd...) + if err != nil { + b.Fatalf("redis-benchmark failed with: %v", err) + } + + // Stop time while we parse results. + b.StopTimer() + result, err := parseOperation(operation, out) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", out, err) + } + b.ReportMetric(result, operation) // operations per second + b.StartTimer() + } + }) + } +} + +// parseOperation grabs the metric operations per second from redis-benchmark output. +func parseOperation(operation, data string) (float64, error) { + re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, operation)) + match := re.FindStringSubmatch(data) + // If no match, simply don't add it to the result map. + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find %s in %s", operation, data) + } + return strconv.ParseFloat(match[2], 64) +} + +// TestParser tests the parser on sample data. +func TestParser(t *testing.T) { + sampleData := ` + "PING_INLINE","48661.80" + "PING_BULK","50301.81" + "SET","48923.68" + "GET","49382.71" + "INCR","49975.02" + "LPUSH","49875.31" + "RPUSH","50276.52" + "LPOP","50327.12" + "RPOP","50556.12" + "SADD","49504.95" + "HSET","49504.95" + "SPOP","50025.02" + "LPUSH (needed to benchmark LRANGE)","48875.86" + "LRANGE_100 (first 100 elements)","33955.86" + "LRANGE_300 (first 300 elements)","16550.81" + "LRANGE_500 (first 450 elements)","13653.74" + "LRANGE_600 (first 600 elements)","11219.57" + "MSET (10 keys)","44682.75" + ` + wants := map[string]float64{ + "PING_INLINE": 48661.80, + "PING_BULK": 50301.81, + "SET": 48923.68, + "GET": 49382.71, + "INCR": 49975.02, + "LPUSH": 49875.31, + "RPUSH": 50276.52, + "LPOP": 50327.12, + "RPOP": 50556.12, + "SADD": 49504.95, + "HSET": 49504.95, + "SPOP": 50025.02, + "LRANGE_100": 33955.86, + "LRANGE_300": 16550.81, + "LRANGE_500": 13653.74, + "LRANGE_600": 11219.57, + "MSET": 44682.75, + } + for op, want := range wants { + if got, err := parseOperation(op, sampleData); err != nil { + t.Fatalf("failed to parse %s: %v", op, err) + } else if want != got { + t.Fatalf("wanted %f for op %s, got %f", want, op, got) + } + } +} diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD new file mode 100644 index 000000000..79327b57c --- /dev/null +++ b/test/benchmarks/fs/BUILD @@ -0,0 +1,30 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "fs", + testonly = 1, + srcs = ["fs.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "fs_test", + size = "large", + srcs = [ + "bazel_test.go", + "fio_test.go", + ], + library = ":fs", + tags = [ + # Requires docker and runsc to be configured before test runs. + "local", + "manual", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], +) diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go new file mode 100644 index 000000000..f4236ba37 --- /dev/null +++ b/test/benchmarks/fs/bazel_test.go @@ -0,0 +1,119 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "fmt" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// Note: CleanCache versions of this test require running with root permissions. +func BenchmarkBuildABSL(b *testing.B) { + runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...") +} + +// Note: CleanCache versions of this test require running with root permissions. +// Note: This test takes on the order of 10m per permutation for runsc on kvm. +func BenchmarkBuildRunsc(b *testing.B) { + runBuildBenchmark(b, "benchmarks/runsc", "/gvisor", "runsc:runsc") +} + +func runBuildBenchmark(b *testing.B, image, workdir, target string) { + b.Helper() + // Get a machine from the Harness on which to run. + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + // Dimensions here are clean/dirty cache (do or don't drop caches) + // and if the mount on which we are compiling is a tmpfs/bind mount. + benchmarks := []struct { + name string + clearCache bool // clearCache drops caches before running. + tmpfs bool // tmpfs will run compilation on a tmpfs. + }{ + {name: "CleanCache", clearCache: true, tmpfs: false}, + {name: "DirtyCache", clearCache: false, tmpfs: false}, + {name: "CleanCacheTmpfs", clearCache: true, tmpfs: true}, + {name: "DirtyCacheTmpfs", clearCache: false, tmpfs: true}, + } + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + // Grab a container. + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Start a container and sleep by an order of b.N. + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: image, + }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { + b.Fatalf("run failed with: %v", err) + } + + // If we are running on a tmpfs, copy to /tmp which is a tmpfs. + if bm.tmpfs { + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + "cp", "-r", workdir, "/tmp/."); err != nil { + b.Fatal("failed to copy directory: %v %s", err, out) + } + workdir = "/tmp" + workdir + } + + // Restart profiles after the copy. + container.RestartProfiles() + b.ResetTimer() + // Drop Caches and bazel clean should happen inside the loop as we may use + // time options with b.N. (e.g. Run for an hour.) + for i := 0; i < b.N; i++ { + b.StopTimer() + // Drop Caches for clear cache runs. + if bm.clearCache { + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + } + b.StartTimer() + + got, err := container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: workdir, + }, "bazel", "build", "-c", "opt", target) + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StopTimer() + + want := "Build completed successfully" + if !strings.Contains(got, want) { + b.Fatalf("string %s not in: %s", want, got) + } + // Clean bazel in case we use b.N. + _, err = container.Exec(ctx, dockerutil.ExecOpts{ + WorkDir: workdir, + }, "bazel", "clean") + if err != nil { + b.Fatalf("build failed with: %v", err) + } + b.StartTimer() + } + }) + } +} diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go new file mode 100644 index 000000000..75d52726a --- /dev/null +++ b/test/benchmarks/fs/fio_test.go @@ -0,0 +1,369 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package fs + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +type fioTestCase struct { + test string // test to run: read, write, randread, randwrite. + size string // total size to be read/written of format N[GMK] (e.g. 5G). + blocksize string // blocksize to be read/write of format N[GMK] (e.g. 4K). + iodepth int // iodepth for reads/writes. + time int // time to run the test in seconds, usually for rand(read/write). +} + +// makeCmdFromTestcase makes a fio command. +func (f *fioTestCase) makeCmdFromTestcase(filename string) []string { + cmd := []string{"fio", "--output-format=json", "--ioengine=sync"} + cmd = append(cmd, fmt.Sprintf("--name=%s", f.test)) + cmd = append(cmd, fmt.Sprintf("--size=%s", f.size)) + cmd = append(cmd, fmt.Sprintf("--blocksize=%s", f.blocksize)) + cmd = append(cmd, fmt.Sprintf("--filename=%s", filename)) + cmd = append(cmd, fmt.Sprintf("--iodepth=%d", f.iodepth)) + cmd = append(cmd, fmt.Sprintf("--rw=%s", f.test)) + if f.time != 0 { + cmd = append(cmd, "--time_based") + cmd = append(cmd, fmt.Sprintf("--runtime=%d", f.time)) + } + return cmd +} + +// BenchmarkFio runs fio on the runtime under test. There are 4 basic test +// cases each run on a tmpfs mount and a bind mount. Fio requires root so that +// caches can be dropped. +func BenchmarkFio(b *testing.B) { + testCases := []fioTestCase{ + fioTestCase{ + test: "write", + size: "5G", + blocksize: "1M", + iodepth: 4, + }, + fioTestCase{ + test: "read", + size: "5G", + blocksize: "1M", + iodepth: 4, + }, + fioTestCase{ + test: "randwrite", + size: "5G", + blocksize: "4K", + iodepth: 4, + time: 30, + }, + fioTestCase{ + test: "randread", + size: "5G", + blocksize: "4K", + iodepth: 4, + time: 30, + }, + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} { + for _, tc := range testCases { + testName := strings.Title(tc.test) + strings.Title(string(fsType)) + b.Run(testName, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + // Directory and filename inside container where fio will read/write. + outdir := "/data" + outfile := filepath.Join(outdir, "test.txt") + + // Make the required mount and grab a cleanup for bind mounts + // as they are backed by a temp directory (mktemp). + mnt, mountCleanup, err := makeMount(machine, fsType, outdir) + if err != nil { + b.Fatalf("failed to make mount: %v", err) + } + defer mountCleanup() + cmd := tc.makeCmdFromTestcase(outfile) + + // Start the container with the mount. + if err := container.Spawn( + ctx, + dockerutil.RunOpts{ + Image: "benchmarks/fio", + Mounts: []mount.Mount{ + mnt, + }, + }, + // Sleep on the order of b.N. + "sleep", fmt.Sprintf("%d", 1000*b.N), + ); err != nil { + b.Fatalf("failed to start fio container with: %v", err) + } + + // For reads, we need a file to read so make one inside the container. + if strings.Contains(tc.test, "read") { + fallocateCmd := fmt.Sprintf("fallocate -l %s %s", tc.size, outfile) + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + strings.Split(fallocateCmd, " ")...); err != nil { + b.Fatalf("failed to create readable file on mount: %v, %s", err, out) + } + } + + // Drop caches just before running. + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches with %v. You probably need root.", err) + } + container.RestartProfiles() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Run fio. + data, err := container.Exec(ctx, dockerutil.ExecOpts{}, cmd...) + if err != nil { + b.Fatalf("failed to run cmd %v: %v", cmd, err) + } + b.StopTimer() + // Parse the output and report the metrics. + isRead := strings.Contains(tc.test, "read") + bw, err := parseBandwidth(data, isRead) + if err != nil { + b.Fatalf("failed to parse bandwidth from %s with: %v", data, err) + } + b.ReportMetric(bw, "bandwidth") // in b/s. + + iops, err := parseIOps(data, isRead) + if err != nil { + b.Fatalf("failed to parse iops from %s with: %v", data, err) + } + b.ReportMetric(iops, "iops") + // If b.N is used (i.e. we run for an hour), we should drop caches + // after each run. + if err := harness.DropCaches(machine); err != nil { + b.Fatalf("failed to drop caches: %v", err) + } + b.StartTimer() + } + }) + } + } +} + +// makeMount makes a mount and cleanup based on the requested type. Bind +// and volume mounts are backed by a temp directory made with mktemp. +// tmpfs mounts require no such backing and are just made. +// It is up to the caller to call the returned cleanup. +func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) { + switch mountType { + case mount.TypeVolume, mount.TypeBind: + dir, err := machine.RunCommand("mktemp", "-d") + if err != nil { + return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err) + } + dir = strings.TrimSuffix(dir, "\n") + + out, err := machine.RunCommand("chmod", "777", dir) + if err != nil { + machine.RunCommand("rm", "-rf", dir) + return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) + } + return mount.Mount{ + Target: target, + Source: dir, + Type: mount.TypeBind, + }, func() { machine.RunCommand("rm", "-rf", dir) }, nil + case mount.TypeTmpfs: + return mount.Mount{ + Target: target, + Type: mount.TypeTmpfs, + }, func() {}, nil + default: + return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType) + } +} + +// parseBandwidth reports the bandwidth in b/s. +func parseBandwidth(data string, isRead bool) (float64, error) { + if isRead { + result, err := parseFioJSON(data, "read", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil + } + result, err := parseFioJSON(data, "write", "bw") + if err != nil { + return 0, err + } + return 1024 * result, nil +} + +// parseIOps reports the write IO per second metric. +func parseIOps(data string, isRead bool) (float64, error) { + if isRead { + return parseFioJSON(data, "read", "iops") + } + return parseFioJSON(data, "write", "iops") +} + +// fioResult is for parsing FioJSON. +type fioResult struct { + Jobs []fioJob +} + +// fioJob is for parsing FioJSON. +type fioJob map[string]json.RawMessage + +// fioMetrics is for parsing FioJSON. +type fioMetrics map[string]json.RawMessage + +// parseFioJSON parses data and grabs "op" (read or write) and "metric" +// (bw or iops) from the JSON. +func parseFioJSON(data, op, metric string) (float64, error) { + var result fioResult + if err := json.Unmarshal([]byte(data), &result); err != nil { + return 0, fmt.Errorf("could not unmarshal data: %v", err) + } + + if len(result.Jobs) < 1 { + return 0, fmt.Errorf("no jobs present to parse") + } + + var metrics fioMetrics + if err := json.Unmarshal(result.Jobs[0][op], &metrics); err != nil { + return 0, fmt.Errorf("could not unmarshal jobs: %v", err) + } + + if _, ok := metrics[metric]; !ok { + return 0, fmt.Errorf("no metric found for op: %s", op) + } + return strconv.ParseFloat(string(metrics[metric]), 64) +} + +// TestParsers tests that the parsers work on sampleData. +func TestParsers(t *testing.T) { + sampleData := ` +{ + "fio version" : "fio-3.1", + "timestamp" : 1554837456, + "timestamp_ms" : 1554837456621, + "time" : "Tue Apr 9 19:17:36 2019", + "jobs" : [ + { + "jobname" : "test", + "groupid" : 0, + "error" : 0, + "eta" : 2147483647, + "elapsed" : 1, + "job options" : { + "name" : "test", + "ioengine" : "sync", + "size" : "1073741824", + "filename" : "/disk/file.dat", + "iodepth" : "4", + "bs" : "4096", + "rw" : "write" + }, + "read" : { + "io_bytes" : 0, + "io_kbytes" : 0, + "bw" : 123456, + "iops" : 1234.5678, + "runtime" : 0, + "total_ios" : 0, + "short_ios" : 0, + "bw_min" : 0, + "bw_max" : 0, + "bw_agg" : 0.000000, + "bw_mean" : 0.000000, + "bw_dev" : 0.000000, + "bw_samples" : 0, + "iops_min" : 0, + "iops_max" : 0, + "iops_mean" : 0.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 0 + }, + "write" : { + "io_bytes" : 1073741824, + "io_kbytes" : 1048576, + "bw" : 1753471, + "iops" : 438367.892977, + "runtime" : 598, + "total_ios" : 262144, + "bw_min" : 1731120, + "bw_max" : 1731120, + "bw_agg" : 98.725328, + "bw_mean" : 1731120.000000, + "bw_dev" : 0.000000, + "bw_samples" : 1, + "iops_min" : 432780, + "iops_max" : 432780, + "iops_mean" : 432780.000000, + "iops_stddev" : 0.000000, + "iops_samples" : 1 + } + } + ] +} +` + // WriteBandwidth. + got, err := parseBandwidth(sampleData, false) + var want float64 = 1753471.0 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadBandwidth. + got, err = parseBandwidth(sampleData, true) + want = 123456 * 1024 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // WriteIOps. + got, err = parseIOps(sampleData, false) + want = 438367.892977 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + // ReadIOps. + got, err = parseIOps(sampleData, true) + want = 1234.5678 + if err != nil { + t.Fatalf("parse failed with err: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/fs/fs.go b/test/benchmarks/fs/fs.go new file mode 100644 index 000000000..e5ca28c3b --- /dev/null +++ b/test/benchmarks/fs/fs.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fs holds benchmarks around filesystem performance. +package fs + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package fs. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD new file mode 100644 index 000000000..c2e316709 --- /dev/null +++ b/test/benchmarks/harness/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "harness", + testonly = 1, + srcs = [ + "harness.go", + "machine.go", + "util.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) diff --git a/test/benchmarks/harness/harness.go b/test/benchmarks/harness/harness.go new file mode 100644 index 000000000..68bd7b4cf --- /dev/null +++ b/test/benchmarks/harness/harness.go @@ -0,0 +1,38 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package harness holds utility code for running benchmarks on Docker. +package harness + +import ( + "flag" + + "gvisor.dev/gvisor/pkg/test/dockerutil" +) + +// Harness is a handle for managing state in benchmark runs. +type Harness struct { +} + +// Init performs any harness initilialization before runs. +func (h *Harness) Init() error { + flag.Parse() + dockerutil.EnsureSupportedDockerVersion() + return nil +} + +// GetMachine returns this run's implementation of machine. +func (h *Harness) GetMachine() (Machine, error) { + return &localMachine{}, nil +} diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go new file mode 100644 index 000000000..88e5e841b --- /dev/null +++ b/test/benchmarks/harness/machine.go @@ -0,0 +1,81 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "net" + "os/exec" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// Machine describes a real machine for use in benchmarks. +type Machine interface { + // GetContainer gets a container from the machine. The container uses the + // runtime under test and is profiled if requested by flags. + GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // GetNativeContainer gets a native container from the machine. Native containers + // use runc by default and are not profiled. + GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container + + // RunCommand runs cmd on this machine. + RunCommand(cmd string, args ...string) (string, error) + + // Returns IP Address for the machine. + IPAddress() (net.IP, error) + + // CleanUp cleans up this machine. + CleanUp() +} + +// localMachine describes this machine. +type localMachine struct { +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeContainer(ctx, logger) +} + +// GetContainer implements Machine.GetContainer for localMachine. +func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container { + return dockerutil.MakeNativeContainer(ctx, logger) +} + +// RunCommand implements Machine.RunCommand for localMachine. +func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) { + c := exec.Command(cmd, args...) + out, err := c.CombinedOutput() + return string(out), err +} + +// IPAddress implements Machine.IPAddress. +func (l *localMachine) IPAddress() (net.IP, error) { + conn, err := net.Dial("udp", "8.8.8.8:80") + if err != nil { + return nil, err + } + defer conn.Close() + + addr := conn.LocalAddr().(*net.UDPAddr) + return addr.IP, nil +} + +// CleanUp implements Machine.CleanUp and does nothing for localMachine. +func (*localMachine) CleanUp() { +} diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go new file mode 100644 index 000000000..bc551c582 --- /dev/null +++ b/test/benchmarks/harness/util.go @@ -0,0 +1,46 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package harness + +import ( + "context" + "fmt" + "net" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// WaitUntilServing grabs a container from `machine` and waits for a server at +// IP:port. +func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error { + var logger testutil.DefaultLogger = "netcat" + netcat := machine.GetNativeContainer(ctx, logger) + defer netcat.CleanUp(ctx) + + cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server, port) + _, err := netcat.Run(ctx, dockerutil.RunOpts{ + Image: "packetdrill", + }, "sh", "-c", cmd) + return err +} + +// DropCaches drops caches on the provided machine. Requires root. +func DropCaches(machine Machine) error { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync | sysctl vm.drop_caches=3"); err != nil { + return fmt.Errorf("failed to drop caches: %v logs: %s", err, out) + } + return nil +} diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD new file mode 100644 index 000000000..6c41fc4f6 --- /dev/null +++ b/test/benchmarks/media/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "media", + testonly = 1, + srcs = ["media.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "media_test", + size = "large", + srcs = ["ffmpeg_test.go"], + library = ":media", + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go new file mode 100644 index 000000000..7822dfad7 --- /dev/null +++ b/test/benchmarks/media/ffmpeg_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package media + +import ( + "context" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkFfmpeg runs ffmpeg in a container and records runtime. +// BenchmarkFfmpeg should run as root to drop caches. +func BenchmarkFfmpeg(b *testing.B) { + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ffmpeg", + }, cmd...); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go new file mode 100644 index 000000000..c7b35b758 --- /dev/null +++ b/test/benchmarks/media/media.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package media holds benchmarks around media processing applications. +package media + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package media. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD new file mode 100644 index 000000000..2430b60a7 --- /dev/null +++ b/test/benchmarks/ml/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "ml", + testonly = 1, + srcs = ["ml.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "ml_test", + size = "large", + srcs = ["tensorflow_test.go"], + library = ":ml", + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/ml/ml.go b/test/benchmarks/ml/ml.go new file mode 100644 index 000000000..13282d7bb --- /dev/null +++ b/test/benchmarks/ml/ml.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ml holds benchmarks around machine learning performance. +package ml + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package ml. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/ml/tensorflow_test.go b/test/benchmarks/ml/tensorflow_test.go new file mode 100644 index 000000000..f7746897d --- /dev/null +++ b/test/benchmarks/ml/tensorflow_test.go @@ -0,0 +1,69 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package ml + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkTensorflow runs workloads from a TensorFlow tutorial. +// See: https://github.com/aymericdamien/TensorFlow-Examples +func BenchmarkTensorflow(b *testing.B) { + workloads := map[string]string{ + "GradientDecisionTree": "2_BasicModels/gradient_boosted_decision_tree.py", + "Kmeans": "2_BasicModels/kmeans.py", + "LogisticRegression": "2_BasicModels/logistic_regression.py", + "NearestNeighbor": "2_BasicModels/nearest_neighbor.py", + "RandomForest": "2_BasicModels/random_forest.py", + "ConvolutionalNetwork": "3_NeuralNetworks/convolutional_network.py", + "MultilayerPerceptron": "3_NeuralNetworks/multilayer_perceptron.py", + "NeuralNetwork": "3_NeuralNetworks/neural_network.py", + } + + machine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for name, workload := range workloads { + b.Run(name, func(b *testing.B) { + ctx := context.Background() + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + if err := harness.DropCaches(machine); err != nil { + b.Skipf("failed to drop caches: %v. You probably need root.", err) + } + b.StartTimer() + + if out, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/tensorflow", + Env: []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"}, + WorkDir: "/TensorFlow-Examples/examples", + }, "python", workload); err != nil { + b.Fatalf("failed to run container: %v logs: %s", err, out) + } + } + }) + } + +} diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD new file mode 100644 index 000000000..b47400590 --- /dev/null +++ b/test/benchmarks/network/BUILD @@ -0,0 +1,31 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "network", + testonly = 1, + srcs = ["network.go"], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "network_test", + size = "large", + srcs = [ + "httpd_test.go", + "iperf_test.go", + "node_test.go", + ], + library = ":network", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + deps = [ + "//pkg/test/dockerutil", + "//pkg/test/testutil", + "//test/benchmarks/harness", + ], +) diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go new file mode 100644 index 000000000..fe23ca949 --- /dev/null +++ b/test/benchmarks/network/httpd_test.go @@ -0,0 +1,277 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "regexp" + "strconv" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// see Dockerfile '//images/benchmarks/httpd'. +var docs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1000Kb": "latin1000k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + +// BenchmarkHttpdConcurrency iterates the concurrency argument and tests +// how well the runtime under test handles requests in parallel. +func BenchmarkHttpdConcurrency(b *testing.B) { + // Grab a machine for the client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get client: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get server: %v", err) + } + defer serverMachine.CleanUp() + + // The test iterates over client concurrency, so set other parameters. + requests := 10000 + concurrency := []int{1, 5, 10, 25} + doc := docs["10Kb"] + + for _, c := range concurrency { + b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { + runHttpd(b, clientMachine, serverMachine, doc, requests, c) + }) + } +} + +// BenchmarkHttpdDocSize iterates over different sized payloads, testing how +// well the runtime handles different payload sizes. +func BenchmarkHttpdDocSize(b *testing.B) { + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + requests := 10000 + concurrency := 1 + + for name, filename := range docs { + b.Run(name, func(b *testing.B) { + runHttpd(b, clientMachine, serverMachine, filename, requests, concurrency) + }) + } +} + +// runHttpd runs a single test run. +func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, doc string, requests, concurrency int) { + b.Helper() + + // Grab a container from the server. + ctx := context.Background() + server := serverMachine.GetContainer(ctx, b) + defer server.CleanUp(ctx) + + // Copy the docs to /tmp and serve from there. + cmd := "mkdir -p /tmp/html; cp -r /local /tmp/html/.; apache2 -X" + port := 80 + + // Start the server. + server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/httpd", + Ports: []int{port}, + Env: []string{ + // Standard environmental variables for httpd. + "APACHE_RUN_DIR=/tmp", + "APACHE_RUN_USER=nobody", + "APACHE_RUN_GROUP=nogroup", + "APACHE_LOG_DIR=/tmp", + "APACHE_PID_FILE=/tmp/apache.pid", + }, + }, "sh", "-c", cmd) + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Check the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + + // Grab a client. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + + path := fmt.Sprintf("http://%s:%d/%s", ip, servingPort, doc) + // See apachebench (ab) for flags. + cmd = fmt.Sprintf("ab -n %d -c %d %s", requests, concurrency, path) + + b.ResetTimer() + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ab", + }, "sh", "-c", cmd) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + + b.StopTimer() + + // Parse and report custom metrics. + transferRate, err := parseTransferRate(out) + if err != nil { + b.Logf("failed to parse transferrate: %v", err) + } + b.ReportMetric(transferRate*1024, "transfer_rate") // Convert from Kb/s to b/s. + + latency, err := parseLatency(out) + if err != nil { + b.Logf("failed to parse latency: %v", err) + } + b.ReportMetric(latency/1000, "mean_latency") // Convert from ms to s. + + reqPerSecond, err := parseRequestsPerSecond(out) + if err != nil { + b.Logf("failed to parse requests per second: %v", err) + } + b.ReportMetric(reqPerSecond, "requests_per_second") + + b.StartTimer() + } +} + +var transferRateRE = regexp.MustCompile(`Transfer rate:\s+(\d+\.?\d+?)\s+\[Kbytes/sec\]\s+received`) + +// parseTransferRate parses transfer rate from apachebench output. +func parseTransferRate(data string) (float64, error) { + match := transferRateRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var latencyRE = regexp.MustCompile(`Total:\s+\d+\s+(\d+)\s+(\d+\.?\d+?)\s+\d+\s+\d+\s`) + +// parseLatency parses latency from apachebench output. +func parseLatency(data string) (float64, error) { + match := latencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var requestsPerSecondRE = regexp.MustCompile(`Requests per second:\s+(\d+\.?\d+?)\s+`) + +// parseRequestsPerSecond parses requests per second from apachebench output. +func parseRequestsPerSecond(data string) (float64, error) { + match := requestsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// Sample output from apachebench. +const sampleData = `This is ApacheBench, Version 2.3 <$Revision: 1826891 $> +Copyright 1996 Adam Twiss, Zeus Technology Ltd, http://www.zeustech.net/ +Licensed to The Apache Software Foundation, http://www.apache.org/ + +Benchmarking 10.10.10.10 (be patient).....done + + +Server Software: Apache/2.4.38 +Server Hostname: 10.10.10.10 +Server Port: 80 + +Document Path: /latin10k.txt +Document Length: 210 bytes + +Concurrency Level: 1 +Time taken for tests: 0.180 seconds +Complete requests: 100 +Failed requests: 0 +Non-2xx responses: 100 +Total transferred: 38800 bytes +HTML transferred: 21000 bytes +Requests per second: 556.44 [#/sec] (mean) +Time per request: 1.797 [ms] (mean) +Time per request: 1.797 [ms] (mean, across all concurrent requests) +Transfer rate: 210.84 [Kbytes/sec] received + +Connection Times (ms) + min mean[+/-sd] median max +Connect: 0 0 0.2 0 2 +Processing: 1 2 1.0 1 8 +Waiting: 1 1 1.0 1 7 +Total: 1 2 1.2 1 10 + +Percentage of the requests served within a certain time (ms) + 50% 1 + 66% 2 + 75% 2 + 80% 2 + 90% 2 + 95% 3 + 98% 7 + 99% 10 + 100% 10 (longest request)` + +// TestParsers checks the parsers work. +func TestParsers(t *testing.T) { + want := 210.84 + got, err := parseTransferRate(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseTransferRate got: %f, want: %f", got, want) + } + + want = 2.0 + got, err = parseLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseLatency got: %f, want: %f", got, want) + } + + want = 556.44 + got, err = parseRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse transfer rate with error: %v", err) + } else if got != want { + t.Fatalf("parseRequestsPerSecond got: %f, want: %f", got, want) + } +} diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go new file mode 100644 index 000000000..a5e198e14 --- /dev/null +++ b/test/benchmarks/network/iperf_test.go @@ -0,0 +1,150 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +func BenchmarkIperf(b *testing.B) { + const time = 10 // time in seconds to run the client. + + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + ctx := context.Background() + for _, bm := range []struct { + name string + clientFunc func(context.Context, testutil.Logger) *dockerutil.Container + serverFunc func(context.Context, testutil.Logger) *dockerutil.Container + }{ + // We are either measuring the server or the client. The other should be + // runc. e.g. Upload sees how fast the runtime under test uploads to a native + // server. + { + name: "Upload", + clientFunc: clientMachine.GetContainer, + serverFunc: serverMachine.GetNativeContainer, + }, + { + name: "Download", + clientFunc: clientMachine.GetNativeContainer, + serverFunc: serverMachine.GetContainer, + }, + } { + b.Run(bm.name, func(b *testing.B) { + // Set up the containers. + server := bm.serverFunc(ctx, b) + defer server.CleanUp(ctx) + client := bm.clientFunc(ctx, b) + defer client.CleanUp(ctx) + + // iperf serves on port 5001 by default. + port := 5001 + + // Start the server. + if err := server.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + Ports: []int{port}, + }, "iperf", "-s"); err != nil { + b.Fatalf("failed to start server with: %v", err) + } + + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find port %d: %v", port, err) + } + + // Make sure the server is up and serving before we run. + if err := harness.WaitUntilServing(ctx, clientMachine, ip, servingPort); err != nil { + b.Fatalf("failed to wait for server: %v", err) + } + + // iperf report in Kb realtime + cmd := fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", time, ip.String(), servingPort) + + // Run the client. + b.ResetTimer() + + // Restart the server profiles. If the server isn't being profiled + // this does nothing. + server.RestartProfiles() + for i := 0; i < b.N; i++ { + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/iperf", + }, strings.Split(cmd, " ")...) + if err != nil { + b.Fatalf("failed to run client: %v", err) + } + b.StopTimer() + + // Parse bandwidth and report it. + bW, err := bandwidth(out) + if err != nil { + b.Fatalf("failed to parse bandwitdth from %s: %v", out, err) + } + b.ReportMetric(bW*1024, "bandwidth") // Convert from Kb/s to b/s. + b.StartTimer() + } + }) + } +} + +// bandwidth parses the Bandwidth number from an iperf report. A sample is below. +func bandwidth(data string) (float64, error) { + re := regexp.MustCompile(`\[\s*\d+\][^\n]+\s+(\d+\.?\d*)\s+KBytes/sec`) + match := re.FindStringSubmatch(data) + if len(match) < 1 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +func TestParser(t *testing.T) { + sampleData := ` +------------------------------------------------------------ +Client connecting to 10.138.15.215, TCP port 32779 +TCP window size: 45.0 KByte (default) +------------------------------------------------------------ +[ 3] local 10.138.15.216 port 32866 connected with 10.138.15.215 port 32779 +[ ID] Interval Transfer Bandwidth +[ 3] 0.0-10.0 sec 459520 KBytes 45900 KBytes/sec +` + bandwidth, err := bandwidth(sampleData) + if err != nil || bandwidth != 45900 { + t.Fatalf("failed with: %v and %f", err, bandwidth) + } +} diff --git a/test/benchmarks/network/network.go b/test/benchmarks/network/network.go new file mode 100644 index 000000000..ce17ddb94 --- /dev/null +++ b/test/benchmarks/network/network.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package network holds benchmarks around raw network performance. +package network + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var h harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + h.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go new file mode 100644 index 000000000..f9278ab66 --- /dev/null +++ b/test/benchmarks/network/node_test.go @@ -0,0 +1,261 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkNode runs 10K requests using 'hey' against a Node server run on +// 'runtime'. The server responds to requests by grabbing some data in a +// redis instance and returns the data in its reponse. The test loops through +// increasing amounts of concurency for requests. +func BenchmarkNode(b *testing.B) { + requests := 10000 + concurrency := []int{1, 5, 10, 25} + + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + runNode(b, requests, c) + }) + } +} + +// runNode runs the test for a given # of requests and concurrency. +func runNode(b *testing.B, requests, concurrency int) { + b.Helper() + + // The machine to hold Redis and the Node Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatal("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Node runs on port 8080. + port := 8080 + + // Start-up the Node server. + nodeApp := serverMachine.GetContainer(ctx, b) + if err := nodeApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + }, "node", "index.js", redisIP.String()); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer nodeApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := nodeApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort) + + heyCmd := strings.Split(fmt.Sprintf("hey -n %d -c %d http://%s:%d/", requests, concurrency, servingIP, servingPort), " ") + + nodeApp.RestartProfiles() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + requests, err := parseHeyRequestsPerSecond(out) + if err != nil { + b.Fatalf("failed to parse requests per second: %v", err) + } + b.ReportMetric(requests, "requests_per_second") + + bw, err := parseHeyBandwidth(out) + if err != nil { + b.Fatalf("failed to parse bandwidth: %v", err) + } + b.ReportMetric(bw, "bandwidth") + + ave, err := parseHeyAverageLatency(out) + if err != nil { + b.Fatalf("failed to parse average latency: %v", err) + } + b.ReportMetric(ave, "average_latency") + b.StartTimer() + } +} + +var heyReqPerSecondRE = regexp.MustCompile(`Requests/sec:\s*(\d+\.?\d+?)\s+`) + +// parseHeyRequestsPerSecond finds requests per second from hey output. +func parseHeyRequestsPerSecond(data string) (float64, error) { + match := heyReqPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get bandwidth: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +var heyAverageLatencyRE = regexp.MustCompile(`Average:\s*(\d+\.?\d+?)\s+secs`) + +// parseHeyAverageLatency finds Average Latency in seconds form hey output. +func parseHeyAverageLatency(data string) (float64, error) { + match := heyAverageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get average latency match%d : %s", len(match), data) + } + return strconv.ParseFloat(match[1], 64) +} + +var heySizePerRequestRE = regexp.MustCompile(`Size/request:\s*(\d+\.?\d+?)\s+bytes`) + +// parseHeyBandwidth computes bandwidth from request/sec * bytes/request +// and reports in bytes/second. +func parseHeyBandwidth(data string) (float64, error) { + match := heyReqPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get requests per second: %s", data) + } + reqPerSecond, err := strconv.ParseFloat(match[1], 64) + if err != nil { + return 0, fmt.Errorf("failed to convert %s to float", match[1]) + } + + match = heySizePerRequestRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("failed get average latency: %s", data) + } + requestSize, err := strconv.ParseFloat(match[1], 64) + return requestSize * reqPerSecond, err +} + +// TestHeyParsers tests that the parsers work with sample output. +func TestHeyParsers(t *testing.T) { + sampleData := ` + Summary: + Total: 2.2391 secs + Slowest: 1.6292 secs + Fastest: 0.0066 secs + Average: 0.5351 secs + Requests/sec: 89.3202 + + Total data: 841200 bytes + Size/request: 4206 bytes + + Response time histogram: + 0.007 [1] | + 0.169 [0] | + 0.331 [149] |■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ + 0.493 [0] | + 0.656 [0] | + 0.818 [0] | + 0.980 [0] | + 1.142 [0] | + 1.305 [0] | + 1.467 [49] |■■■■■■■■■■■■■ + 1.629 [1] | + + + Latency distribution: + 10% in 0.2149 secs + 25% in 0.2449 secs + 50% in 0.2703 secs + 75% in 1.3315 secs + 90% in 1.4045 secs + 95% in 1.4232 secs + 99% in 1.4362 secs + + Details (average, fastest, slowest): + DNS+dialup: 0.0002 secs, 0.0066 secs, 1.6292 secs + DNS-lookup: 0.0000 secs, 0.0000 secs, 0.0000 secs + req write: 0.0000 secs, 0.0000 secs, 0.0012 secs + resp wait: 0.5225 secs, 0.0064 secs, 1.4346 secs + resp read: 0.0122 secs, 0.0001 secs, 0.2006 secs + + Status code distribution: + [200] 200 responses + ` + want := 89.3202 + got, err := parseHeyRequestsPerSecond(sampleData) + if err != nil { + t.Fatalf("failed to parse request per second with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + want = 89.3202 * 4206 + got, err = parseHeyBandwidth(sampleData) + if err != nil { + t.Fatalf("failed to parse bandwidth with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + + want = 0.5351 + got, err = parseHeyAverageLatency(sampleData) + if err != nil { + t.Fatalf("failed to parse average latency with: %v", err) + } else if got != want { + t.Fatalf("got: %f, want: %f", got, want) + } + +} diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 44cce0e3b..29a84f184 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -23,6 +23,7 @@ go_test( "//pkg/test/dockerutil", "//pkg/test/testutil", "//runsc/specutils", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go index 6a63b1232..b47df447c 100644 --- a/test/e2e/exec_test.go +++ b/test/e2e/exec_test.go @@ -22,12 +22,10 @@ package integration import ( + "context" "fmt" - "os" - "os/exec" "strconv" "strings" - "syscall" "testing" "time" @@ -39,18 +37,19 @@ import ( // Test that exec uses the exact same capability set as the container. func TestExecCapabilities(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { t.Fatalf("docker run failed: %v", err) } // Check that capability. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -61,7 +60,7 @@ func TestExecCapabilities(t *testing.T) { t.Log("Root capabilities:", want) // Now check that exec'd process capabilities match the root. - got, err := d.Exec(dockerutil.RunOpts{}, "grep", "CapEff:", "/proc/self/status") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "grep", "CapEff:", "/proc/self/status") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -74,11 +73,12 @@ func TestExecCapabilities(t *testing.T) { // Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW // which is removed from the container when --net-raw=false. func TestExecPrivileged(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with all capabilities dropped. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", CapDrop: []string{"all"}, }, "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil { @@ -86,7 +86,7 @@ func TestExecPrivileged(t *testing.T) { } // Check that all capabilities where dropped from container. - matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second) + matches, err := d.WaitForOutputSubmatch(ctx, "CapEff:\t([0-9a-f]+)\n", 5*time.Second) if err != nil { t.Fatalf("WaitForOutputSubmatch() timeout: %v", err) } @@ -104,7 +104,7 @@ func TestExecPrivileged(t *testing.T) { // Check that 'exec --privileged' adds all capabilities, except for // CAP_NET_RAW. - got, err := d.Exec(dockerutil.RunOpts{ + got, err := d.Exec(ctx, dockerutil.ExecOpts{ Privileged: true, }, "grep", "CapEff:", "/proc/self/status") if err != nil { @@ -118,76 +118,59 @@ func TestExecPrivileged(t *testing.T) { } func TestExecJobControl(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } - // Exec 'sh' with an attached pty. - if _, err := d.Exec(dockerutil.RunOpts{ - Pty: func(cmd *exec.Cmd, ptmx *os.File) { - // Call "sleep 100 | cat" in the shell. We pipe to cat - // so that there will be two processes in the - // foreground process group. - if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) - - // Send a ^C to the pty, which should kill sleep and - // cat, but not the shell. \x03 is ASCII "end of - // text", which is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // The shell should still be alive at this point. Sleep - // should have exited with code 2+128=130. We'll exit - // with 10 plus that number, so that we can be sure - // that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Exec process should exit with code 10+130=140. - ps, err := cmd.Process.Wait() - if err != nil { - t.Fatalf("error waiting for exec process: %v", err) - } - ws := ps.Sys().(syscall.WaitStatus) - if !ws.Exited() { - t.Errorf("ws.Exited got false, want true") - } - if got, want := ws.ExitStatus(), 140; got != want { - t.Errorf("ws.ExitedStatus got %d, want %d", got, want) - } - }, - }, "sh"); err != nil { + p, err := d.ExecProcess(ctx, dockerutil.ExecOpts{UseTTY: true}, "/bin/sh") + if err != nil { t.Fatalf("docker exec failed: %v", err) } + + if _, err = p.Write(time.Second, []byte("sleep 100 | cat\n")); err != nil { + t.Fatalf("error exit: %v", err) + } + time.Sleep(time.Second) + + if _, err = p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) + } + + if _, err = p.Write(time.Second, []byte("exit $(expr $? + 10)\n")); err != nil { + t.Fatalf("error exit: %v", err) + } + + want := 140 + got, err := p.WaitExitStatus(ctx) + if err != nil { + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("wait for exit returned: %d want: %d", got, want) + } } // Test that failure to exec returns proper error message. func TestExecError(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } // Attempt to exec a binary that doesn't exist. - out, err := d.Exec(dockerutil.RunOpts{}, "no_can_find") + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "no_can_find") if err == nil { t.Fatalf("docker exec didn't fail") } @@ -198,11 +181,12 @@ func TestExecError(t *testing.T) { // Test that exec inherits environment from run. func TestExecEnv(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with env FOO=BAR. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", Env: []string{"FOO=BAR"}, }, "sleep", "1000"); err != nil { @@ -210,7 +194,7 @@ func TestExecEnv(t *testing.T) { } // Exec "echo $FOO". - got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $FOO") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $FOO") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -222,11 +206,12 @@ func TestExecEnv(t *testing.T) { // TestRunEnvHasHome tests that run always has HOME environment set. func TestRunEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Exec "echo $HOME". The 'bin' user's home dir is '/bin'. - got, err := d.Run(dockerutil.RunOpts{ + got, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", User: "bin", }, "/bin/sh", "-c", "echo $HOME") @@ -243,17 +228,18 @@ func TestRunEnvHasHome(t *testing.T) { // Test that exec always has HOME environment set, even when not set in run. func TestExecEnvHasHome(t *testing.T) { // Base alpine image does not have any environment variables set. - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "1000"); err != nil { t.Fatalf("docker run failed: %v", err) } // Exec "echo $HOME", and expect to see "/root". - got, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "echo $HOME") + got, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "echo $HOME") if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -265,12 +251,12 @@ func TestExecEnvHasHome(t *testing.T) { newUID := 1234 newHome := "/foo/bar" cmd := fmt.Sprintf("mkdir -p -m 777 %q && adduser foo -D -u %d -h %q", newHome, newUID, newHome) - if _, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", cmd); err != nil { + if _, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker exec failed: %v", err) } // Execute the same as the new user and expect newHome. - got, err = d.Exec(dockerutil.RunOpts{ + got, err = d.Exec(ctx, dockerutil.ExecOpts{ User: strconv.Itoa(newUID), }, "/bin/sh", "-c", "echo $HOME") if err != nil { diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 60e739c6a..6fe6d304f 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -22,24 +22,27 @@ package integration import ( + "context" "flag" "fmt" "io/ioutil" "net" "net/http" "os" - "os/exec" "path/filepath" "strconv" "strings" - "syscall" "testing" "time" + "github.com/docker/docker/api/types/mount" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait is the default wait time used for tests. +const defaultWait = time.Minute + // httpRequestSucceeds sends a request to a given url and checks that the status is OK. func httpRequestSucceeds(client http.Client, server string, port int) error { url := fmt.Sprintf("http://%s:%d", server, port) @@ -56,37 +59,38 @@ func httpRequestSucceeds(client http.Client, server string, port int) error { // TestLifeCycle tests a basic Create/Start/Stop docker container life cycle. func TestLifeCycle(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Create(dockerutil.RunOpts{ + if err := d.Create(ctx, dockerutil.RunOpts{ Image: "basic/nginx", Ports: []int{80}, }); err != nil { t.Fatalf("docker create failed: %v", err) } - if err := d.Start(); err != nil { + if err := d.Start(ctx); err != nil { t.Fatalf("docker start failed: %v", err) } // Test that container is working. - port, err := d.FindPort(80) + port, err := d.FindPort(ctx, 80) if err != nil { t.Fatalf("docker.FindPort(80) failed: %v", err) } - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Errorf("http request failed: %v", err) } - if err := d.Stop(); err != nil { + if err := d.Stop(ctx); err != nil { t.Fatalf("docker stop failed: %v", err) } - if err := d.Remove(); err != nil { + if err := d.Remove(ctx); err != nil { t.Fatalf("docker rm failed: %v", err) } } @@ -96,11 +100,12 @@ func TestPauseResume(t *testing.T) { t.Skip("Checkpoint is not supported.") } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", Ports: []int{8080}, // See Dockerfile. }); err != nil { @@ -108,27 +113,28 @@ func TestPauseResume(t *testing.T) { } // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } - if err := d.Pause(); err != nil { + if err := d.Pause(ctx); err != nil { t.Fatalf("docker pause failed: %v", err) } // Check if container is paused. + client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute. switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { case nil: t.Errorf("http req expected to fail but it succeeded") @@ -140,16 +146,17 @@ func TestPauseResume(t *testing.T) { t.Errorf("http req got unexpected error %v", v) } - if err := d.Unpause(); err != nil { + if err := d.Unpause(ctx); err != nil { t.Fatalf("docker unpause failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. + client = http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } @@ -160,11 +167,12 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", Ports: []int{8080}, // See Dockerfile. }); err != nil { @@ -172,31 +180,31 @@ func TestCheckpointRestore(t *testing.T) { } // Create a snapshot. - if err := d.Checkpoint("test"); err != nil { + if err := d.Checkpoint(ctx, "test"); err != nil { t.Fatalf("docker checkpoint failed: %v", err) } - if _, err := d.Wait(30 * time.Second); err != nil { + if err := d.WaitTimeout(ctx, defaultWait); err != nil { t.Fatalf("wait failed: %v", err) } // TODO(b/143498576): Remove Poll after github.com/moby/moby/issues/38963 is fixed. - if err := testutil.Poll(func() error { return d.Restore("test") }, 15*time.Second); err != nil { + if err := testutil.Poll(func() error { return d.Restore(ctx, "test") }, defaultWait); err != nil { t.Fatalf("docker restore failed: %v", err) } // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { t.Fatalf("docker.FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. - client := http.Client{Timeout: time.Duration(2 * time.Second)} + client := http.Client{Timeout: defaultWait} if err := httpRequestSucceeds(client, "localhost", port); err != nil { t.Error("http request failed:", err) } @@ -204,26 +212,27 @@ func TestCheckpointRestore(t *testing.T) { // Create client and server that talk to each other using the local IP. func TestConnectToSelf(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Creates server that replies "server" and exists. Sleeps at the end because // 'docker exec' gets killed if the init process exists before it can finish. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/ubuntu", }, "/bin/sh", "-c", "echo server | nc -l -p 8080 && sleep 1"); err != nil { t.Fatalf("docker run failed: %v", err) } // Finds IP address for host. - ip, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") + ip, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "cat /etc/hosts | grep ${HOSTNAME} | awk '{print $1}'") if err != nil { t.Fatalf("docker exec failed: %v", err) } ip = strings.TrimRight(ip, "\n") // Runs client that sends "client" to the server and exits. - reply, err := d.Exec(dockerutil.RunOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) + reply, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", fmt.Sprintf("echo client | nc %s 8080", ip)) if err != nil { t.Fatalf("docker exec failed: %v", err) } @@ -232,21 +241,22 @@ func TestConnectToSelf(t *testing.T) { if want := "server\n"; reply != want { t.Errorf("Error on server, want: %q, got: %q", want, reply) } - if _, err := d.WaitForOutput("^client\n$", 1*time.Second); err != nil { + if _, err := d.WaitForOutput(ctx, "^client\n$", defaultWait); err != nil { t.Fatalf("docker.WaitForOutput(client) timeout: %v", err) } } func TestMemLimit(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // N.B. Because the size of the memory file may grow in large chunks, // there is a minimum threshold of 1GB for the MemTotal figure. - allocMemory := 1024 * 1024 - out, err := d.Run(dockerutil.RunOpts{ + allocMemory := 1024 * 1024 // In kb. + out, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", - Memory: allocMemory, // In kB. + Memory: allocMemory * 1024, // In bytes. }, "sh", "-c", "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'") if err != nil { t.Fatalf("docker run failed: %v", err) @@ -272,13 +282,14 @@ func TestMemLimit(t *testing.T) { } func TestNumCPU(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Read how many cores are in the container. - out, err := d.Run(dockerutil.RunOpts{ - Image: "basic/alpine", - Extra: []string{"--cpuset-cpus=0"}, + out, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/alpine", + CpusetCpus: "0", }, "sh", "-c", "cat /proc/cpuinfo | grep 'processor.*:' | wc -l") if err != nil { t.Fatalf("docker run failed: %v", err) @@ -296,48 +307,34 @@ func TestNumCPU(t *testing.T) { // TestJobControl tests that job control characters are handled properly. func TestJobControl(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container with an attached PTY. - if _, err := d.Run(dockerutil.RunOpts{ + p, err := d.SpawnProcess(ctx, dockerutil.RunOpts{ Image: "basic/alpine", - Pty: func(_ *exec.Cmd, ptmx *os.File) { - // Call "sleep 100" in the shell. - if _, err := ptmx.Write([]byte("sleep 100\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - - // Give shell a few seconds to start executing the sleep. - time.Sleep(2 * time.Second) + }, "sh", "-c", "sleep 100 | cat") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + // Give shell a few seconds to start executing the sleep. + time.Sleep(2 * time.Second) - // Send a ^C to the pty, which should kill sleep, but - // not the shell. \x03 is ASCII "end of text", which - // is the same as ^C. - if _, err := ptmx.Write([]byte{'\x03'}); err != nil { - t.Fatalf("error writing to pty: %v", err) - } + if _, err := p.Write(time.Second, []byte{0x03}); err != nil { + t.Fatalf("error exit: %v", err) + } - // The shell should still be alive at this point. Sleep - // should have exited with code 2+128=130. We'll exit - // with 10 plus that number, so that we can be sure - // that the shell did not get signalled. - if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil { - t.Fatalf("error writing to pty: %v", err) - } - }, - }, "sh"); err != nil { - t.Fatalf("docker run failed: %v", err) + if err := d.WaitTimeout(ctx, 3*time.Second); err != nil { + t.Fatalf("WaitTimeout failed: %v", err) } - // Wait for the container to exit. - got, err := d.Wait(5 * time.Second) + want := 130 + got, err := p.WaitExitStatus(ctx) if err != nil { - t.Fatalf("error getting exit code: %v", err) - } - // Container should exit with code 10+130=140. - if want := syscall.WaitStatus(140); got != want { - t.Errorf("container exited with code %d want %d", got, want) + t.Fatalf("wait for exit failed with: %v", err) + } else if got != want { + t.Fatalf("got: %d want: %d", got, want) } } @@ -356,15 +353,16 @@ func TestWorkingDirCreation(t *testing.T) { name += "-readonly" } t.Run(name, func(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) opts := dockerutil.RunOpts{ Image: "basic/alpine", WorkDir: tc.workingDir, ReadOnly: readonly, } - got, err := d.Run(opts, "sh", "-c", "echo ${PWD}") + got, err := d.Run(ctx, opts, "sh", "-c", "echo ${PWD}") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -378,11 +376,12 @@ func TestWorkingDirCreation(t *testing.T) { // TestTmpFile checks that files inside '/tmp' are not overridden. func TestTmpFile(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - opts := dockerutil.RunOpts{Image: "tmpfile"} - got, err := d.Run(opts, "cat", "/tmp/foo/file.txt") + opts := dockerutil.RunOpts{Image: "basic/tmpfile"} + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -393,6 +392,7 @@ func TestTmpFile(t *testing.T) { // TestTmpMount checks that mounts inside '/tmp' are not overridden. func TestTmpMount(t *testing.T) { + ctx := context.Background() dir, err := ioutil.TempDir(testutil.TmpDir(), "tmp-mount") if err != nil { t.Fatalf("TempDir(): %v", err) @@ -401,19 +401,20 @@ func TestTmpMount(t *testing.T) { if err := ioutil.WriteFile(filepath.Join(dir, "file.txt"), []byte("123"), 0666); err != nil { t.Fatalf("WriteFile(): %v", err) } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) opts := dockerutil.RunOpts{ Image: "basic/alpine", - Mounts: []dockerutil.Mount{ + Mounts: []mount.Mount{ { + Type: mount.TypeBind, Source: dir, Target: "/tmp/foo", }, }, } - got, err := d.Run(opts, "cat", "/tmp/foo/file.txt") + got, err := d.Run(ctx, opts, "cat", "/tmp/foo/file.txt") if err != nil { t.Fatalf("docker run failed: %v", err) } @@ -426,14 +427,43 @@ func TestTmpMount(t *testing.T) { // runsc to hide the incoherence of FDs opened before and after overlayfs // copy-up on the host. func TestHostOverlayfsCopyUp(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if _, err := d.Run(dockerutil.RunOpts{ - Image: "hostoverlaytest", + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", + WorkDir: "/root", + }, "./test_copy_up"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) + } +} + +// TestHostOverlayfsRewindDir tests that rewinddir() "causes the directory +// stream to refer to the current state of the corresponding directory, as a +// call to opendir() would have done" as required by POSIX, when the directory +// in question is host overlayfs. +// +// This test specifically targets host overlayfs because, per POSIX, "if a file +// is removed from or added to the directory after the most recent call to +// opendir() or rewinddir(), whether a subsequent call to readdir() returns an +// entry for that file is unspecified"; the host filesystems used by other +// automated tests yield newly-added files from readdir() even if the fsgofer +// does not explicitly rewinddir(), but overlayfs does not. +func TestHostOverlayfsRewindDir(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/hostoverlaytest", WorkDir: "/root", - }, "./test"); err != nil { + }, "./test_rewinddir"); err != nil { t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) } } diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go index 327a2174c..70bbe5121 100644 --- a/test/e2e/regression_test.go +++ b/test/e2e/regression_test.go @@ -15,6 +15,7 @@ package integration import ( + "context" "strings" "testing" @@ -27,11 +28,12 @@ import ( // Prerequisite: the directory where the socket file is created must not have // been open for write before bind(2) is called. func TestBindOverlay(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Run the container. - got, err := d.Run(dockerutil.RunOpts{ + got, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/ubuntu", }, "bash", "-c", "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p") if err != nil { diff --git a/test/image/image_test.go b/test/image/image_test.go index 3e4321480..ac6186688 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -22,6 +22,7 @@ package image import ( + "context" "flag" "fmt" "io/ioutil" @@ -36,12 +37,20 @@ import ( "gvisor.dev/gvisor/pkg/test/testutil" ) +// defaultWait defines how long to wait for progress. +// +// See BUILD: This is at least a "large" test, so allow up to 1 minute for any +// given "wait" step. Note that all tests are run in parallel, which may cause +// individual slow-downs (but a huge speed-up in aggregate). +const defaultWait = time.Minute + func TestHelloWorld(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Run the basic container. - out, err := d.Run(dockerutil.RunOpts{ + out, err := d.Run(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "echo", "Hello world!") if err != nil { @@ -107,8 +116,9 @@ func testHTTPServer(t *testing.T, port int) { } func TestHttpd(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. opts := dockerutil.RunOpts{ @@ -116,18 +126,18 @@ func TestHttpd(t *testing.T) { Ports: []int{80}, } d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt") - if err := d.Spawn(opts); err != nil { + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } // Find where port 80 is mapped to. - port, err := d.FindPort(80) + port, err := d.FindPort(ctx, 80) if err != nil { t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } @@ -135,8 +145,9 @@ func TestHttpd(t *testing.T) { } func TestNginx(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the container. opts := dockerutil.RunOpts{ @@ -144,18 +155,18 @@ func TestNginx(t *testing.T) { Ports: []int{80}, } d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt") - if err := d.Spawn(opts); err != nil { + if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } // Find where port 80 is mapped to. - port, err := d.FindPort(80) + port, err := d.FindPort(ctx, 80) if err != nil { t.Fatalf("FindPort(80) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } @@ -163,11 +174,12 @@ func TestNginx(t *testing.T) { } func TestMysql(t *testing.T) { - server := dockerutil.MakeDocker(t) - defer server.CleanUp() + ctx := context.Background() + server := dockerutil.MakeContainer(ctx, t) + defer server.CleanUp(ctx) // Start the container. - if err := server.Spawn(dockerutil.RunOpts{ + if err := server.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/mysql", Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, }); err != nil { @@ -175,42 +187,38 @@ func TestMysql(t *testing.T) { } // Wait until it's up and running. - if _, err := server.WaitForOutput("port: 3306 MySQL Community Server", 3*time.Minute); err != nil { + if _, err := server.WaitForOutput(ctx, "port: 3306 MySQL Community Server", defaultWait); err != nil { t.Fatalf("WaitForOutput() timeout: %v", err) } // Generate the client and copy in the SQL payload. - client := dockerutil.MakeDocker(t) - defer client.CleanUp() + client := dockerutil.MakeContainer(ctx, t) + defer client.CleanUp(ctx) // Tell mysql client to connect to the server and execute the file in // verbose mode to verify the output. opts := dockerutil.RunOpts{ Image: "basic/mysql", - Links: []dockerutil.Link{ - { - Source: server, - Target: "mysql", - }, - }, + Links: []string{server.MakeLink("mysql")}, } client.CopyFiles(&opts, "/sql", "test/image/mysql.sql") - if _, err := client.Run(opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { + if _, err := client.Run(ctx, opts, "mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql"); err != nil { t.Fatalf("docker run failed: %v", err) } // Ensure file executed to the end and shutdown mysql. - if _, err := server.WaitForOutput("mysqld: Shutdown complete", 30*time.Second); err != nil { + if _, err := server.WaitForOutput(ctx, "mysqld: Shutdown complete", defaultWait); err != nil { t.Fatalf("WaitForOutput() timeout: %v", err) } } func TestTomcat(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start the server. - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/tomcat", Ports: []int{8080}, }); err != nil { @@ -218,13 +226,13 @@ func TestTomcat(t *testing.T) { } // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, 30*time.Second); err != nil { + if err := testutil.WaitForHTTP(port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } @@ -240,8 +248,9 @@ func TestTomcat(t *testing.T) { } func TestRuby(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Execute the ruby workload. opts := dockerutil.RunOpts{ @@ -249,18 +258,18 @@ func TestRuby(t *testing.T) { Ports: []int{8080}, } d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh") - if err := d.Spawn(opts, "/src/ruby.sh"); err != nil { + if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } // Find where port 8080 is mapped to. - port, err := d.FindPort(8080) + port, err := d.FindPort(ctx, 8080) if err != nil { t.Fatalf("FindPort(8080) failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, 1*time.Minute); err != nil { + if err := testutil.WaitForHTTP(port, time.Minute); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } @@ -283,20 +292,21 @@ func TestRuby(t *testing.T) { } func TestStdio(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) wantStdout := "hello stdout" wantStderr := "bonjour stderr" cmd := fmt.Sprintf("echo %q; echo %q 1>&2;", wantStdout, wantStderr) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "/bin/sh", "-c", cmd); err != nil { t.Fatalf("docker run failed: %v", err) } for _, want := range []string{wantStdout, wantStderr} { - if _, err := d.WaitForOutput(want, 5*time.Second); err != nil { + if _, err := d.WaitForOutput(ctx, want, defaultWait); err != nil { t.Fatalf("docker didn't get output %q : %v", want, err) } } diff --git a/test/iptables/BUILD b/test/iptables/BUILD index 3e29ca90d..40b63ebbe 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -20,6 +20,7 @@ go_library( go_test( name = "iptables_test", + size = "large", srcs = [ "iptables_test.go", ], diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 872021358..5737ee317 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -25,7 +25,6 @@ const ( dropPort = 2401 acceptPort = 2402 sendloopDuration = 2 * time.Second - network = "udp4" chainName = "foochain" ) @@ -62,8 +61,8 @@ func (FilterInputDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { +func (FilterInputDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } @@ -80,8 +79,8 @@ func (FilterInputDropUDP) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDP) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. @@ -93,8 +92,8 @@ func (FilterInputDropOnlyUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { +func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } @@ -107,7 +106,7 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropOnlyUDP) LocalAction(ip net.IP) error { +func (FilterInputDropOnlyUDP) LocalAction(ip net.IP, ipv6 bool) error { // Try to establish a TCP connection with the container, which should // succeed. return connectTCP(ip, acceptPort, sendloopDuration) @@ -122,8 +121,8 @@ func (FilterInputDropUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } @@ -140,8 +139,8 @@ func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDPPort) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port @@ -154,8 +153,8 @@ func (FilterInputDropDifferentUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } @@ -168,8 +167,8 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. @@ -181,8 +180,8 @@ func (FilterInputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } @@ -195,7 +194,7 @@ func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPDestPort) LocalAction(ip net.IP) error { +func (FilterInputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. for start := time.Now(); time.Since(start) < sendloopDuration; { if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil { @@ -215,9 +214,9 @@ func (FilterInputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { +func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { // Drop anything from an ephemeral port. - if err := filterTable("-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { + if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { return err } @@ -230,7 +229,7 @@ func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP) error { +func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. for start := time.Now(); time.Since(start) < sendloopDuration; { if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil { @@ -250,8 +249,8 @@ func (FilterInputDropAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropAll) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil { +func (FilterInputDropAll) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil { return err } @@ -268,8 +267,8 @@ func (FilterInputDropAll) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropAll) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // FilterInputMultiUDPRules verifies that multiple UDP rules are applied @@ -284,17 +283,17 @@ func (FilterInputMultiUDPRules) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error { +func (FilterInputMultiUDPRules) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, {"-L"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputMultiUDPRules) LocalAction(ip net.IP) error { +func (FilterInputMultiUDPRules) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -309,14 +308,14 @@ func (FilterInputRequireProtocolUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { +func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") } return nil } -func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP) error { +func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -330,18 +329,18 @@ func (FilterInputCreateUserChain) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error { +func (FilterInputCreateUserChain) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ // Create a chain. {"-N", chainName}, // Add a simple rule to the chain. {"-A", chainName, "-j", "DROP"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputCreateUserChain) LocalAction(ip net.IP) error { +func (FilterInputCreateUserChain) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -355,17 +354,17 @@ func (FilterInputDefaultPolicyAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error { +func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP, ipv6 bool) error { // Set the default policy to accept, then receive a packet. - if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil { + if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil { return err } return listenUDP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputDefaultPolicyDrop tests the default DROP policy. @@ -377,8 +376,8 @@ func (FilterInputDefaultPolicyDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { - if err := filterTable("-P", "INPUT", "DROP"); err != nil { +func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil { return err } @@ -395,8 +394,8 @@ func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes @@ -409,7 +408,7 @@ func (FilterInputReturnUnderflow) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { +func (FilterInputReturnUnderflow) ContainerAction(ip net.IP, ipv6 bool) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. rules := [][]string{ @@ -417,7 +416,7 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { {"-A", "INPUT", "-j", "DROP"}, {"-P", "INPUT", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -427,8 +426,8 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputReturnUnderflow) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputSerializeJump verifies that we can serialize jumps. @@ -440,18 +439,18 @@ func (FilterInputSerializeJump) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSerializeJump) ContainerAction(ip net.IP) error { +func (FilterInputSerializeJump) ContainerAction(ip net.IP, ipv6 bool) error { // Write a JUMP rule, the serialize it with `-L`. rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-L"}, } - return filterTableRules(rules) + return filterTableRules(ipv6, rules) } // LocalAction implements TestCase.LocalAction. -func (FilterInputSerializeJump) LocalAction(ip net.IP) error { +func (FilterInputSerializeJump) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -465,14 +464,14 @@ func (FilterInputJumpBasic) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { +func (FilterInputJumpBasic) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-A", chainName, "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -481,8 +480,8 @@ func (FilterInputJumpBasic) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBasic) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpBasic) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputJumpReturn jumps, returns, and executes a rule. @@ -494,7 +493,7 @@ func (FilterInputJumpReturn) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { +func (FilterInputJumpReturn) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-P", "INPUT", "ACCEPT"}, @@ -502,7 +501,7 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { {"-A", chainName, "-j", "RETURN"}, {"-A", chainName, "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -511,8 +510,8 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturn) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpReturn) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. @@ -524,14 +523,14 @@ func (FilterInputJumpReturnDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { +func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, {"-A", "INPUT", "-j", "DROP"}, {"-A", chainName, "-j", "RETURN"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -548,8 +547,8 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputJumpReturnDrop) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. @@ -561,15 +560,15 @@ func (FilterInputJumpBuiltin) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil { +func (FilterInputJumpBuiltin) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil { return fmt.Errorf("iptables should be unable to jump to a built-in chain") } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error { +func (FilterInputJumpBuiltin) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -583,7 +582,7 @@ func (FilterInputJumpTwice) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { +func (FilterInputJumpTwice) ContainerAction(ip net.IP, ipv6 bool) error { const chainName2 = chainName + "2" rules := [][]string{ {"-P", "INPUT", "DROP"}, @@ -593,7 +592,7 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { {"-A", chainName, "-j", chainName2}, {"-A", "INPUT", "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -603,8 +602,8 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpTwice) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpTwice) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputDestination verifies that we can filter packets via `-d @@ -617,8 +616,8 @@ func (FilterInputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDestination) ContainerAction(ip net.IP) error { - addrs, err := localAddrs() +func (FilterInputDestination) ContainerAction(ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) if err != nil { return err } @@ -629,7 +628,7 @@ func (FilterInputDestination) ContainerAction(ip net.IP) error { for _, addr := range addrs { rules = append(rules, []string{"-A", "INPUT", "-d", addr, "-j", "ACCEPT"}) } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -637,8 +636,8 @@ func (FilterInputDestination) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDestination) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDestination) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputInvertDestination verifies that we can filter packets via `! -d @@ -651,14 +650,14 @@ func (FilterInputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { +func (FilterInputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ {"-P", "INPUT", "DROP"}, - {"-A", "INPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + {"-A", "INPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -666,8 +665,8 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertDestination) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputSource verifies that we can filter packets via `-s @@ -680,14 +679,14 @@ func (FilterInputSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSource) ContainerAction(ip net.IP) error { +func (FilterInputSource) ContainerAction(ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets from this // machine. rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-A", "INPUT", "-s", fmt.Sprintf("%v", ip), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -695,8 +694,8 @@ func (FilterInputSource) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputSource) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputSource) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // FilterInputInvertSource verifies that we can filter packets via `! -s @@ -709,14 +708,14 @@ func (FilterInputInvertSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertSource) ContainerAction(ip net.IP) error { +func (FilterInputInvertSource) ContainerAction(ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ {"-P", "INPUT", "DROP"}, - {"-A", "INPUT", "!", "-s", localIP, "-j", "ACCEPT"}, + {"-A", "INPUT", "!", "-s", localIP(ipv6), "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -724,6 +723,6 @@ func (FilterInputInvertSource) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertSource) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertSource) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index ba0d6fc29..c1d83b471 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -52,8 +52,8 @@ func (FilterOutputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { +func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } @@ -66,7 +66,7 @@ func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP) error { +func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) } @@ -84,8 +84,8 @@ func (FilterOutputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { +func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } @@ -98,7 +98,7 @@ func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP) error { +func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) } @@ -115,8 +115,8 @@ func (FilterOutputAcceptTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } @@ -125,7 +125,7 @@ func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP) error { +func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { return connectTCP(ip, acceptPort, sendloopDuration) } @@ -138,8 +138,8 @@ func (FilterOutputDropTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } @@ -152,7 +152,7 @@ func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPOwner) LocalAction(ip net.IP) error { +func (FilterOutputDropTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) } @@ -169,8 +169,8 @@ func (FilterOutputAcceptUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } @@ -179,7 +179,7 @@ func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP) error { +func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { // Listen for UDP packets on acceptPort. return listenUDP(acceptPort, sendloopDuration) } @@ -193,8 +193,8 @@ func (FilterOutputDropUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } @@ -203,7 +203,7 @@ func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropUDPOwner) LocalAction(ip net.IP) error { +func (FilterOutputDropUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { // Listen for UDP packets on dropPort. if err := listenUDP(dropPort, sendloopDuration); err == nil { return fmt.Errorf("packets should not be received") @@ -222,8 +222,8 @@ func (FilterOutputOwnerFail) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { +func (FilterOutputOwnerFail) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { return fmt.Errorf("Invalid argument") } @@ -231,7 +231,7 @@ func (FilterOutputOwnerFail) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputOwnerFail) LocalAction(ip net.IP) error { +func (FilterOutputOwnerFail) LocalAction(ip net.IP, ipv6 bool) error { // no-op. return nil } @@ -245,8 +245,8 @@ func (FilterOutputAcceptGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { +func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { return err } @@ -255,7 +255,7 @@ func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP) error { +func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { return connectTCP(ip, acceptPort, sendloopDuration) } @@ -268,8 +268,8 @@ func (FilterOutputDropGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { +func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { return err } @@ -282,7 +282,7 @@ func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropGIDOwner) LocalAction(ip net.IP) error { +func (FilterOutputDropGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -299,12 +299,12 @@ func (FilterOutputInvertGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -317,7 +317,7 @@ func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP) error { +func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -334,12 +334,12 @@ func (FilterOutputInvertUIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -348,7 +348,7 @@ func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP) error { +func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP, ipv6 bool) error { return connectTCP(ip, acceptPort, sendloopDuration) } @@ -362,12 +362,12 @@ func (FilterOutputInvertUIDAndGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP) error { +func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -380,7 +380,7 @@ func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP) error { +func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -398,12 +398,12 @@ func (FilterOutputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDestination) ContainerAction(ip net.IP) error { +func (FilterOutputDestination) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -411,7 +411,7 @@ func (FilterOutputDestination) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDestination) LocalAction(ip net.IP) error { +func (FilterOutputDestination) LocalAction(ip net.IP, ipv6 bool) error { return listenUDP(acceptPort, sendloopDuration) } @@ -425,12 +425,12 @@ func (FilterOutputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { +func (FilterOutputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { rules := [][]string{ - {"-A", "OUTPUT", "!", "-d", localIP, "-j", "ACCEPT"}, + {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, } - if err := filterTableRules(rules); err != nil { + if err := filterTableRules(ipv6, rules); err != nil { return err } @@ -438,7 +438,7 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertDestination) LocalAction(ip net.IP) error { +func (FilterOutputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { return listenUDP(acceptPort, sendloopDuration) } @@ -452,12 +452,12 @@ func (FilterOutputInterfaceAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error { +func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") } - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "ACCEPT"); err != nil { return err } @@ -465,7 +465,7 @@ func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceAccept) LocalAction(ip net.IP) error { +func (FilterOutputInterfaceAccept) LocalAction(ip net.IP, ipv6 bool) error { return listenUDP(acceptPort, sendloopDuration) } @@ -479,12 +479,12 @@ func (FilterOutputInterfaceDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error { +func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") } - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", ifname, "-j", "DROP"); err != nil { return err } @@ -492,7 +492,7 @@ func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceDrop) LocalAction(ip net.IP) error { +func (FilterOutputInterfaceDrop) LocalAction(ip net.IP, ipv6 bool) error { if err := listenUDP(acceptPort, sendloopDuration); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) } @@ -510,8 +510,8 @@ func (FilterOutputInterface) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterface) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { +func (FilterOutputInterface) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { return err } @@ -519,7 +519,7 @@ func (FilterOutputInterface) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterface) LocalAction(ip net.IP) error { +func (FilterOutputInterface) LocalAction(ip net.IP, ipv6 bool) error { return listenUDP(acceptPort, sendloopDuration) } @@ -533,8 +533,8 @@ func (FilterOutputInterfaceBeginsWith) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { +func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { return err } @@ -542,7 +542,7 @@ func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP) error { +func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP, ipv6 bool) error { if err := listenUDP(acceptPort, sendloopDuration); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) } @@ -560,8 +560,8 @@ func (FilterOutputInterfaceInvertDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { +func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { return err } @@ -574,7 +574,7 @@ func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP) error { +func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP, ipv6 bool) error { if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -592,8 +592,8 @@ func (FilterOutputInterfaceInvertAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error { - if err := filterTable("-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { +func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP, ipv6 bool) error { + if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { return err } @@ -602,6 +602,6 @@ func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP) error { +func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP, ipv6 bool) error { return connectTCP(ip, acceptPort, sendloopDuration) } diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go index 16cb4f4da..dfbd80cd1 100644 --- a/test/iptables/iptables.go +++ b/test/iptables/iptables.go @@ -40,10 +40,10 @@ type TestCase interface { // ContainerAction runs inside the container. It receives the IP of the // local process. - ContainerAction(ip net.IP) error + ContainerAction(ip net.IP, ipv6 bool) error // LocalAction runs locally. It receives the IP of the container. - LocalAction(ip net.IP) error + LocalAction(ip net.IP, ipv6 bool) error } // Tests maps test names to TestCase. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index 340f9426e..550b6198a 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -15,8 +15,10 @@ package iptables import ( + "context" "fmt" "net" + "reflect" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -33,12 +35,30 @@ import ( // Container output is logged to $TEST_UNDECLARED_OUTPUTS_DIR if it exists, or // to stderr. func singleTest(t *testing.T, test TestCase) { + for _, tc := range []bool{false, true} { + subtest := "IPv4" + if tc { + subtest = "IPv6" + } + t.Run(subtest, func(t *testing.T) { + iptablesTest(t, test, tc) + }) + } +} + +func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) } - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. + if ipv6 && dockerutil.Runtime() != "runc" { + t.Skip("gVisor ip6tables not yet implemented") + } // Create and start the container. opts := dockerutil.RunOpts{ @@ -46,12 +66,16 @@ func singleTest(t *testing.T, test TestCase) { CapAdd: []string{"NET_ADMIN"}, } d.CopyFiles(&opts, "/runner", "test/iptables/runner/runner") - if err := d.Spawn(opts, "/runner/runner", "-name", test.Name()); err != nil { + args := []string{"/runner/runner", "-name", test.Name()} + if ipv6 { + args = append(args, "-ipv6") + } + if err := d.Spawn(ctx, opts, args...); err != nil { t.Fatalf("docker run failed: %v", err) } // Get the container IP. - ip, err := d.FindIP() + ip, err := d.FindIP(ctx, ipv6) if err != nil { t.Fatalf("failed to get container IP: %v", err) } @@ -62,14 +86,14 @@ func singleTest(t *testing.T, test TestCase) { } // Run our side of the test. - if err := test.LocalAction(ip); err != nil { + if err := test.LocalAction(ip, ipv6); err != nil { t.Fatalf("LocalAction failed: %v", err) } // Wait for the final statement. This structure has the side effect // that all container logs will appear within the individual test // context. - if _, err := d.WaitForOutput(TerminalStatement, TestTimeout); err != nil { + if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil { t.Fatalf("test failed: %v", err) } } @@ -83,7 +107,7 @@ func sendIP(ip net.IP) error { // The container may not be listening when we first connect, so retry // upon error. cb := func() error { - c, err := net.DialTCP("tcp4", nil, &contAddr) + c, err := net.DialTCP("tcp", nil, &contAddr) conn = c return err } @@ -260,6 +284,13 @@ func TestNATPreRedirectTCPPort(t *testing.T) { singleTest(t, NATPreRedirectTCPPort{}) } +func TestNATPreRedirectTCPOutgoing(t *testing.T) { + singleTest(t, NATPreRedirectTCPOutgoing{}) +} + +func TestNATOutRedirectTCPIncoming(t *testing.T) { + singleTest(t, NATOutRedirectTCPIncoming{}) +} func TestNATOutRedirectUDPPort(t *testing.T) { singleTest(t, NATOutRedirectUDPPort{}) } @@ -315,3 +346,28 @@ func TestInputSource(t *testing.T) { func TestInputInvertSource(t *testing.T) { singleTest(t, FilterInputInvertSource{}) } + +func TestFilterAddrs(t *testing.T) { + tcs := []struct { + ipv6 bool + addrs []string + want []string + }{ + { + ipv6: false, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"192.168.0.1", "192.168.0.2"}, + }, + { + ipv6: true, + addrs: []string{"192.168.0.1", "192.168.0.2/24", "::1", "::2/128"}, + want: []string{"::1", "::2"}, + }, + } + + for _, tc := range tcs { + if got := filterAddrs(tc.addrs, tc.ipv6); !reflect.DeepEqual(got, tc.want) { + t.Errorf("%v with IPv6 %t: got %v, but wanted %v", tc.addrs, tc.ipv6, got, tc.want) + } + } +} diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 7146edbb9..ca80a4b5f 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -18,27 +18,29 @@ import ( "fmt" "net" "os/exec" + "strings" "time" "gvisor.dev/gvisor/pkg/test/testutil" ) -const iptablesBinary = "iptables" -const localIP = "127.0.0.1" - -// filterTable calls `iptables -t filter` with the given args. -func filterTable(args ...string) error { - return tableCmd("filter", args) +// filterTable calls `ip{6}tables -t filter` with the given args. +func filterTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "filter", args) } -// natTable calls `iptables -t nat` with the given args. -func natTable(args ...string) error { - return tableCmd("nat", args) +// natTable calls `ip{6}tables -t nat` with the given args. +func natTable(ipv6 bool, args ...string) error { + return tableCmd(ipv6, "nat", args) } -func tableCmd(table string, args []string) error { +func tableCmd(ipv6 bool, table string, args []string) error { args = append([]string{"-t", table}, args...) - cmd := exec.Command(iptablesBinary, args...) + binary := "iptables" + if ipv6 { + binary = "ip6tables" + } + cmd := exec.Command(binary, args...) if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out)) } @@ -46,18 +48,18 @@ func tableCmd(table string, args []string) error { } // filterTableRules is like filterTable, but runs multiple iptables commands. -func filterTableRules(argsList [][]string) error { - return tableRules("filter", argsList) +func filterTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "filter", argsList) } // natTableRules is like natTable, but runs multiple iptables commands. -func natTableRules(argsList [][]string) error { - return tableRules("nat", argsList) +func natTableRules(ipv6 bool, argsList [][]string) error { + return tableRules(ipv6, "nat", argsList) } -func tableRules(table string, argsList [][]string) error { +func tableRules(ipv6 bool, table string, argsList [][]string) error { for _, args := range argsList { - if err := tableCmd(table, args); err != nil { + if err := tableCmd(ipv6, table, args); err != nil { return err } } @@ -70,7 +72,7 @@ func listenUDP(port int, timeout time.Duration) error { localAddr := net.UDPAddr{ Port: port, } - conn, err := net.ListenUDP(network, &localAddr) + conn, err := net.ListenUDP("udp", &localAddr) if err != nil { return err } @@ -83,17 +85,42 @@ func listenUDP(port int, timeout time.Duration) error { // sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified // over a duration. func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - // Send packets for a few seconds. + conn, err := connectUDP(ip, port) + if err != nil { + return err + } + defer conn.Close() + loopUDP(conn, duration) + return nil +} + +// spawnUDPLoop works like sendUDPLoop, but returns immediately and sends +// packets in another goroutine. +func spawnUDPLoop(ip net.IP, port int, duration time.Duration) error { + conn, err := connectUDP(ip, port) + if err != nil { + return err + } + go func() { + defer conn.Close() + loopUDP(conn, duration) + }() + return nil +} + +func connectUDP(ip net.IP, port int) (net.Conn, error) { remote := net.UDPAddr{ IP: ip, Port: port, } - conn, err := net.DialUDP(network, nil, &remote) + conn, err := net.DialUDP("udp", nil, &remote) if err != nil { - return err + return nil, err } - defer conn.Close() + return conn, nil +} +func loopUDP(conn net.Conn, duration time.Duration) { to := time.After(duration) for timedOut := false; !timedOut; { // This may return an error (connection refused) if the remote @@ -108,8 +135,6 @@ func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { time.Sleep(200 * time.Millisecond) } } - - return nil } // listenTCP listens for connections on a TCP port. @@ -119,7 +144,7 @@ func listenTCP(port int, timeout time.Duration) error { } // Starts listening on port. - lConn, err := net.ListenTCP("tcp4", &localAddr) + lConn, err := net.ListenTCP("tcp", &localAddr) if err != nil { return err } @@ -157,17 +182,38 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { return nil } -// localAddrs returns a list of local network interface addresses. -func localAddrs() ([]string, error) { +// localAddrs returns a list of local network interface addresses. When ipv6 is +// true, only IPv6 addresses are returned. Otherwise only IPv4 addresses are +// returned. +func localAddrs(ipv6 bool) ([]string, error) { addrs, err := net.InterfaceAddrs() if err != nil { return nil, err } addrStrs := make([]string, 0, len(addrs)) for _, addr := range addrs { - addrStrs = append(addrStrs, addr.String()) + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr.String(), "/") + if len(parts) != 2 { + return nil, fmt.Errorf("bad interface address: %q", addr.String()) + } + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, addr.String()) + } } - return addrStrs, nil + return filterAddrs(addrStrs, ipv6), nil +} + +func filterAddrs(addrs []string, ipv6 bool) []string { + addrStrs := make([]string, 0, len(addrs)) + for _, addr := range addrs { + // Add only IPv4 or only IPv6 addresses. + parts := strings.Split(addr, "/") + if isIPv6 := net.ParseIP(parts[0]).To4() == nil; isIPv6 == ipv6 { + addrStrs = append(addrStrs, parts[0]) + } + } + return addrStrs } // getInterfaceName returns the name of the interface other than loopback. @@ -184,3 +230,17 @@ func getInterfaceName() (string, bool) { return ifname, ifname != "" } + +func localIP(ipv6 bool) string { + if ipv6 { + return "::1" + } + return "127.0.0.1" +} + +func nowhereIP(ipv6 bool) string { + if ipv6 { + return "2001:db8::1" + } + return "192.0.2.1" +} diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 5e54a3963..ac0d91bb2 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -28,6 +28,8 @@ const ( func init() { RegisterTestCase(NATPreRedirectUDPPort{}) RegisterTestCase(NATPreRedirectTCPPort{}) + RegisterTestCase(NATPreRedirectTCPOutgoing{}) + RegisterTestCase(NATOutRedirectTCPIncoming{}) RegisterTestCase(NATOutRedirectUDPPort{}) RegisterTestCase(NATOutRedirectTCPPort{}) RegisterTestCase(NATDropUDP{}) @@ -51,8 +53,8 @@ func (NATPreRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -64,8 +66,8 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // NATPreRedirectTCPPort tests that connections are redirected on specified ports. @@ -77,8 +79,8 @@ func (NATPreRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATPreRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } @@ -87,10 +89,60 @@ func (NATPreRedirectTCPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPPort) LocalAction(ip net.IP) error { +func (NATPreRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { return connectTCP(ip, dropPort, sendloopDuration) } +// NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't +// affected by PREROUTING connection tracking. +type NATPreRedirectTCPOutgoing struct{} + +// Name implements TestCase.Name. +func (NATPreRedirectTCPOutgoing) Name() string { + return "NATPreRedirectTCPOutgoing" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect all incoming TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return connectTCP(ip, acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP, ipv6 bool) error { + return listenTCP(acceptPort, sendloopDuration) +} + +// NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't +// affected by OUTPUT connection tracking. +type NATOutRedirectTCPIncoming struct{} + +// Name implements TestCase.Name. +func (NATOutRedirectTCPIncoming) Name() string { + return "NATOutRedirectTCPIncoming" +} + +// ContainerAction implements TestCase.ContainerAction. +func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP, ipv6 bool) error { + // Redirect all outgoing TCP traffic to a closed port. + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { + return err + } + + // Establish a connection to the host process. + return listenTCP(acceptPort, sendloopDuration) +} + +// LocalAction implements TestCase.LocalAction. +func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP, ipv6 bool) error { + return connectTCP(ip, acceptPort, sendloopDuration) +} + // NATOutRedirectUDPPort tests that packets are redirected to different port. type NATOutRedirectUDPPort struct{} @@ -100,13 +152,12 @@ func (NATOutRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectUDPPort) ContainerAction(ip net.IP) error { - dest := []byte{200, 0, 0, 1} - return loopbackTest(dest, "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +func (NATOutRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { + return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectUDPPort) LocalAction(ip net.IP) error { +func (NATOutRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -121,8 +172,8 @@ func (NATDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { +func (NATDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } @@ -134,8 +185,8 @@ func (NATDropUDP) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATDropUDP) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // NATAcceptAll tests that all UDP packets are accepted. @@ -147,8 +198,8 @@ func (NATAcceptAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { +func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { return err } @@ -160,8 +211,8 @@ func (NATAcceptAll) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATAcceptAll) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // NATOutRedirectIP uses iptables to select packets based on destination IP and @@ -174,14 +225,17 @@ func (NATOutRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectIP) ContainerAction(ip net.IP) error { +func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - dest := net.IP([]byte{200, 0, 0, 2}) - return loopbackTest(dest, "-A", "OUTPUT", "-d", dest.String(), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) + return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "-d", nowhereIP(ipv6), + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectIP) LocalAction(ip net.IP) error { +func (NATOutRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -196,15 +250,15 @@ func (NATOutDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutDontRedirectIP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "OUTPUT", "-d", localIP, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { +func (NATOutDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } return sendUDPLoop(ip, acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATOutDontRedirectIP) LocalAction(ip net.IP) error { +func (NATOutDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { return listenUDP(acceptPort, sendloopDuration) } @@ -217,15 +271,21 @@ func (NATOutRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectInvert) ContainerAction(ip net.IP) error { +func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - dest := []byte{200, 0, 0, 3} - destStr := "200.0.0.2" - return loopbackTest(dest, "-A", "OUTPUT", "!", "-d", destStr, "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) + dest := "192.0.2.2" + if ipv6 { + dest = "2001:db8::2" + } + return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + "-A", "OUTPUT", + "!", "-d", dest, + "-p", "udp", + "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectInvert) LocalAction(ip net.IP) error { +func (NATOutRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -240,8 +300,8 @@ func (NATPreRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectIP) ContainerAction(ip net.IP) error { - addrs, err := localAddrs() +func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { + addrs, err := localAddrs(ipv6) if err != nil { return err } @@ -250,15 +310,15 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP) error { for _, addr := range addrs { rules = append(rules, []string{"-A", "PREROUTING", "-p", "udp", "-d", addr, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)}) } - if err := natTableRules(rules); err != nil { + if err := natTableRules(ipv6, rules); err != nil { return err } return listenUDP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectIP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // NATPreDontRedirectIP tests that iptables matching with "-d" does not match @@ -271,16 +331,16 @@ func (NATPreDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { +func (NATPreDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } return listenUDP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATPreDontRedirectIP) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, acceptPort, sendloopDuration) } // NATPreRedirectInvert tests that iptables can match with "! -d". @@ -292,16 +352,16 @@ func (NATPreRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectInvert) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-p", "udp", "!", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATPreRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } return listenUDP(acceptPort, sendloopDuration) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectInvert) LocalAction(ip net.IP) error { - return sendUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { + return spawnUDPLoop(ip, dropPort, sendloopDuration) } // NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a @@ -314,15 +374,15 @@ func (NATRedirectRequiresProtocol) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP) error { - if err := natTable("-A", "PREROUTING", "-d", localIP, "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { +func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { return errors.New("expected an error using REDIRECT --to-ports without a protocol") } return nil } // LocalAction implements TestCase.LocalAction. -func (NATRedirectRequiresProtocol) LocalAction(ip net.IP) error { +func (NATRedirectRequiresProtocol) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } @@ -336,15 +396,14 @@ func (NATOutRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { - if err := natTable("-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { +func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { + if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } timeout := 20 * time.Second - dest := []byte{127, 0, 0, 1} localAddr := net.TCPAddr{ - IP: dest, + IP: net.ParseIP(localIP(ipv6)), Port: acceptPort, } @@ -372,7 +431,7 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPPort) LocalAction(ip net.IP) error { +func (NATOutRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { return nil } @@ -386,10 +445,10 @@ func (NATLoopbackSkipsPrerouting) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error { +func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. dest := []byte{127, 0, 0, 1} - if err := natTable("-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { + if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } @@ -407,15 +466,15 @@ func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP) error { } // LocalAction implements TestCase.LocalAction. -func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP) error { +func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { // No-op. return nil } // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. -func loopbackTest(dest net.IP, args ...string) error { - if err := natTable(args...); err != nil { +func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { + if err := natTable(ipv6, args...); err != nil { return err } sendCh := make(chan error) diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go index 6f77c0684..69d3ef121 100644 --- a/test/iptables/runner/main.go +++ b/test/iptables/runner/main.go @@ -24,7 +24,10 @@ import ( "gvisor.dev/gvisor/test/iptables" ) -var name = flag.String("name", "", "name of the test to run") +var ( + name = flag.String("name", "", "name of the test to run") + ipv6 = flag.Bool("ipv6", false, "whether the test utilizes ip6tables") +) func main() { flag.Parse() @@ -43,7 +46,7 @@ func main() { } // Run the test. - if err := test.ContainerAction(ip); err != nil { + if err := test.ContainerAction(ip, *ipv6); err != nil { log.Fatalf("Failed running test %q: %v", *name, err) } @@ -57,7 +60,7 @@ func getIP() (net.IP, error) { localAddr := net.TCPAddr{ Port: iptables.IPExchangePort, } - listener, err := net.ListenTCP("tcp4", &localAddr) + listener, err := net.ListenTCP("tcp", &localAddr) if err != nil { return net.IP{}, fmt.Errorf("failed listening for IP: %v", err) } diff --git a/test/packetdrill/defs.bzl b/test/packetdrill/defs.bzl index f499c177b..fc28ce9ba 100644 --- a/test/packetdrill/defs.bzl +++ b/test/packetdrill/defs.bzl @@ -26,7 +26,7 @@ def _packetdrill_test_impl(ctx): transitive_files = depset() if hasattr(ctx.attr._test_runner, "data_runfiles"): - transitive_files = depset(ctx.attr._test_runner.data_runfiles.files) + transitive_files = ctx.attr._test_runner.data_runfiles.files runfiles = ctx.runfiles( files = [test_runner] + ctx.files._init_script + ctx.files.scripts, transitive_files = transitive_files, @@ -60,11 +60,15 @@ _packetdrill_test = rule( implementation = _packetdrill_test_impl, ) -_PACKETDRILL_TAGS = ["local", "manual"] +PACKETDRILL_TAGS = [ + "local", + "manual", + "packetdrill", +] def packetdrill_linux_test(name, **kwargs): if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS + kwargs["tags"] = PACKETDRILL_TAGS _packetdrill_test( name = name, flags = ["--dut_platform", "linux"], @@ -73,7 +77,7 @@ def packetdrill_linux_test(name, **kwargs): def packetdrill_netstack_test(name, **kwargs): if "tags" not in kwargs: - kwargs["tags"] = _PACKETDRILL_TAGS + kwargs["tags"] = PACKETDRILL_TAGS _packetdrill_test( name = name, # This is the default runtime unless diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index a1a5c3612..29d4cc6fe 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -53,7 +53,10 @@ response_in6->set_flowinfo(ntohl(addr_in6->sin6_flowinfo)); response_in6->mutable_addr()->assign( reinterpret_cast<const char *>(&addr_in6->sin6_addr.s6_addr), 16); - response_in6->set_scope_id(ntohl(addr_in6->sin6_scope_id)); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + response_in6->set_scope_id(addr_in6->sin6_scope_id); return ::grpc::Status::OK; } } @@ -89,7 +92,10 @@ addr_in6->sin6_flowinfo = htonl(proto_in6.flowinfo()); proto_in6.addr().copy( reinterpret_cast<char *>(&addr_in6->sin6_addr.s6_addr), 16); - addr_in6->sin6_scope_id = htonl(proto_in6.scope_id()); + // sin6_scope_id is stored in host byte order. + // + // https://www.gnu.org/software/libc/manual/html_node/Internet-Address-Formats.html + addr_in6->sin6_scope_id = proto_in6.scope_id(); *addr_len = sizeof(*addr_in6); break; } diff --git a/test/packetimpact/netdevs/BUILD b/test/packetimpact/netdevs/BUILD index 422bb9b0c..8d1193fed 100644 --- a/test/packetimpact/netdevs/BUILD +++ b/test/packetimpact/netdevs/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package( licenses = ["notice"], @@ -13,3 +13,11 @@ go_library( "//pkg/tcpip/header", ], ) + +go_test( + name = "netdevs_test", + size = "small", + srcs = ["netdevs_test.go"], + library = ":netdevs", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], +) diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go index d2c9cfeaf..eecfe0730 100644 --- a/test/packetimpact/netdevs/netdevs.go +++ b/test/packetimpact/netdevs/netdevs.go @@ -19,6 +19,7 @@ import ( "fmt" "net" "regexp" + "strconv" "strings" "gvisor.dev/gvisor/pkg/tcpip" @@ -27,6 +28,7 @@ import ( // A DeviceInfo represents a network device. type DeviceInfo struct { + ID uint32 MAC net.HardwareAddr IPv4Addr net.IP IPv4Net *net.IPNet @@ -35,7 +37,7 @@ type DeviceInfo struct { } var ( - deviceLine = regexp.MustCompile(`^\s*\d+: (\w+)`) + deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`) linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) @@ -43,6 +45,11 @@ var ( // ParseDevices parses the output from `ip addr show` into a map from device // name to information about the device. +// +// Note: if multiple IPv6 addresses are assigned to a device, the last address +// displayed by `ip addr show` will be used. This is fine for packetimpact +// because we will always only have at most one IPv6 address assigned to each +// device. func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { var currentDevice string var currentInfo DeviceInfo @@ -52,8 +59,12 @@ func ParseDevices(cmdOutput string) (map[string]DeviceInfo, error) { if currentDevice != "" { deviceInfos[currentDevice] = currentInfo } - currentInfo = DeviceInfo{} - currentDevice = m[1] + id, err := strconv.ParseUint(m[1], 10, 32) + if err != nil { + return nil, fmt.Errorf("parsing device ID %s: %w", m[1], err) + } + currentInfo = DeviceInfo{ID: uint32(id)} + currentDevice = m[2] } else if m := linkLine.FindStringSubmatch(line); m != nil { mac, err := net.ParseMAC(m[1]) if err != nil { diff --git a/test/packetimpact/netdevs/netdevs_test.go b/test/packetimpact/netdevs/netdevs_test.go new file mode 100644 index 000000000..24ad12198 --- /dev/null +++ b/test/packetimpact/netdevs/netdevs_test.go @@ -0,0 +1,227 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package netdevs + +import ( + "fmt" + "net" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func mustParseMAC(s string) net.HardwareAddr { + mac, err := net.ParseMAC(s) + if err != nil { + panic(fmt.Sprintf("failed to parse test MAC %q: %s", s, err)) + } + return mac +} + +func TestParseDevices(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + want map[string]DeviceInfo + }{ + { + desc: "v4 and v6", + cmdOutput: ` +1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000 + link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00 + inet 127.0.0.1/8 scope host lo + valid_lft forever preferred_lft forever + inet6 ::1/128 scope host + valid_lft forever preferred_lft forever +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever + inet6 fe80::42:c0ff:fea8:902/64 scope link tentative + valid_lft forever preferred_lft forever +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 223.245.225.10/24 brd 223.245.225.255 scope global eth2 + valid_lft forever preferred_lft forever + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "lo": DeviceInfo{ + ID: 1, + MAC: mustParseMAC("00:00:00:00:00:00"), + IPv4Addr: net.IPv4(127, 0, 0, 1), + IPv4Net: &net.IPNet{ + IP: net.IPv4(127, 0, 0, 0), + Mask: net.CIDRMask(8, 32), + }, + IPv6Addr: net.ParseIP("::1"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("::1"), + Mask: net.CIDRMask(128, 128), + }, + }, + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:c0ff:fea8:902"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth1": DeviceInfo{ + ID: 2617, + MAC: mustParseMAC("02:42:da:33:13:0a"), + IPv4Addr: net.IPv4(218, 51, 19, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(218, 51, 19, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:daff:fe33:130a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv4Addr: net.IPv4(223, 245, 225, 10), + IPv4Net: &net.IPNet{ + IP: net.IPv4(223, 245, 225, 0), + Mask: net.CIDRMask(24, 32), + }, + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + { + desc: "v4 only", + cmdOutput: ` +2613: eth0@if2614: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:c0:a8:09:02 brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 192.168.9.2/24 brd 192.168.9.255 scope global eth0 + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth0": DeviceInfo{ + ID: 2613, + MAC: mustParseMAC("02:42:c0:a8:09:02"), + IPv4Addr: net.IPv4(192, 168, 9, 2), + IPv4Net: &net.IPNet{ + IP: net.IPv4(192, 168, 9, 0), + Mask: net.CIDRMask(24, 32), + }, + }, + }, + }, + { + desc: "v6 only", + cmdOutput: ` +2615: eth2@if2616: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:df:f5:e1:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet6 fe80::42:dfff:fef5:e10a/64 scope link tentative + valid_lft forever preferred_lft forever`, + want: map[string]DeviceInfo{ + "eth2": DeviceInfo{ + ID: 2615, + MAC: mustParseMAC("02:42:df:f5:e1:0a"), + IPv6Addr: net.ParseIP("fe80::42:dfff:fef5:e10a"), + IPv6Net: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + }, + }, + } { + t.Run(v.desc, func(t *testing.T) { + got, err := ParseDevices(v.cmdOutput) + if err != nil { + t.Errorf("ParseDevices(\n%s\n) got unexpected error: %s", v.cmdOutput, err) + } + if diff := cmp.Diff(v.want, got); diff != "" { + t.Errorf("ParseDevices(\n%s\n) got output diff (-want, +got):\n%s", v.cmdOutput, diff) + } + }) + } +} + +func TestParseDevicesErrors(t *testing.T) { + for _, v := range []struct { + desc string + cmdOutput string + }{ + { + desc: "invalid MAC addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a:ffffffff brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v4 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 1234.4321.424242.0/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid v6 addr", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10/24 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80:ffffffff::42:daff:fe33:130a/64 scope link tentative + valid_lft forever preferred_lft forever`, + }, + { + desc: "invalid CIDR missing prefixlen", + cmdOutput: ` +2617: eth1@if2618: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc noqueue state UP group default + link/ether 02:42:da:33:13:0a brd ff:ff:ff:ff:ff:ff link-netnsid 0 + inet 218.51.19.10 brd 218.51.19.255 scope global eth1 + valid_lft forever preferred_lft forever + inet6 fe80::42:daff:fe33:130a scope link tentative + valid_lft forever preferred_lft forever`, + }, + } { + t.Run(v.desc, func(t *testing.T) { + if _, err := ParseDevices(v.cmdOutput); err == nil { + t.Errorf("ParseDevices(\n%s\n) succeeded unexpectedly, want error", v.cmdOutput) + } + }) + } +} diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD index 0b68a760a..bad4f0183 100644 --- a/test/packetimpact/runner/BUILD +++ b/test/packetimpact/runner/BUILD @@ -16,5 +16,6 @@ go_test( deps = [ "//pkg/test/dockerutil", "//test/packetimpact/netdevs", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index ea66b9756..79b3c9162 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -20,12 +20,12 @@ def _packetimpact_test_impl(ctx): ]) ctx.actions.write(bench, bench_content, is_executable = True) - transitive_files = depset() + transitive_files = [] if hasattr(ctx.attr._test_runner, "data_runfiles"): - transitive_files = depset(ctx.attr._test_runner.data_runfiles.files) + transitive_files.append(ctx.attr._test_runner.data_runfiles.files) runfiles = ctx.runfiles( files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, - transitive_files = transitive_files, + transitive_files = depset(transitive = transitive_files), collect_default = True, collect_data = True, ) @@ -55,7 +55,11 @@ _packetimpact_test = rule( implementation = _packetimpact_test_impl, ) -PACKETIMPACT_TAGS = ["local", "manual"] +PACKETIMPACT_TAGS = [ + "local", + "manual", + "packetimpact", +] def packetimpact_linux_test( name, @@ -75,7 +79,7 @@ def packetimpact_linux_test( name = name + "_linux_test", testbench_binary = testbench_binary, flags = ["--dut_platform", "linux"] + expect_failure_flag, - tags = PACKETIMPACT_TAGS + ["packetimpact"], + tags = PACKETIMPACT_TAGS, **kwargs ) @@ -101,7 +105,7 @@ def packetimpact_netstack_test( # This is the default runtime unless # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag, - tags = PACKETIMPACT_TAGS + ["packetimpact"], + tags = PACKETIMPACT_TAGS, **kwargs ) @@ -121,7 +125,10 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure name = testbench_binary, size = size, pure = pure, - tags = PACKETIMPACT_TAGS, + tags = [ + "local", + "manual", + ], **kwargs ) packetimpact_linux_test( diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index 3cac5915f..74e1e6def 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -16,6 +16,7 @@ package packetimpact_test import ( + "context" "flag" "fmt" "io/ioutil" @@ -29,6 +30,7 @@ import ( "testing" "time" + "github.com/docker/docker/api/types/mount" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/packetimpact/netdevs" ) @@ -94,15 +96,16 @@ func TestOne(t *testing.T) { } } dockerutil.EnsureSupportedDockerVersion() + ctx := context.Background() // Create the networks needed for the test. One control network is needed for // the gRPC control packets and one test network on which to transmit the test // packets. - ctrlNet := dockerutil.NewDockerNetwork(logger("ctrlNet")) - testNet := dockerutil.NewDockerNetwork(logger("testNet")) - for _, dn := range []*dockerutil.DockerNetwork{ctrlNet, testNet} { + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { for { - if err := createDockerNetwork(dn); err != nil { + if err := createDockerNetwork(ctx, dn); err != nil { t.Log("creating docker network:", err) const wait = 100 * time.Millisecond t.Logf("sleeping %s and will try creating docker network again", wait) @@ -113,11 +116,19 @@ func TestOne(t *testing.T) { } break } - defer func(dn *dockerutil.DockerNetwork) { - if err := dn.Cleanup(); err != nil { + defer func(dn *dockerutil.Network) { + if err := dn.Cleanup(ctx); err != nil { t.Errorf("unable to cleanup container %s: %s", dn.Name, err) } }(dn) + // Sanity check. + inspect, err := dn.Inspect(ctx) + if err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + } tmpDir, err := ioutil.TempDir("", "container-output") @@ -128,42 +139,51 @@ func TestOne(t *testing.T) { const testOutputDir = "/tmp/testoutput" - runOpts := dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Extra: []string{"--sysctl", "net.ipv6.conf.all.disable_ipv6=0", "--rm", "-v", tmpDir + ":" + testOutputDir}, - Foreground: true, - } - // Create the Docker container for the DUT. - dut := dockerutil.MakeDocker(logger("dut")) + dut := dockerutil.MakeContainer(ctx, logger("dut")) if *dutPlatform == "linux" { - dut.Runtime = "" + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) + } + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, } const containerPosixServerBinary = "/packetimpact/posix_server" dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") - if err := dut.Create(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort); err != nil { - t.Fatalf("unable to create container %s: %s", dut.Name, err) + conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("unable to create container %s: %v", dut.Name, err) } - defer dut.CleanUp() + + defer dut.CleanUp(ctx) // Add ctrlNet as eth1 and testNet as eth2. const testNetDev = "eth2" - if err := addNetworks(dut, dutAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil { + if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { t.Fatal(err) } - if err := dut.Start(); err != nil { + if err := dut.Start(ctx); err != nil { t.Fatalf("unable to start container %s: %s", dut.Name, err) } - if _, err := dut.WaitForOutput("Server listening.*\n", 60*time.Second); err != nil { + if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) } - dutTestDevice, dutDeviceInfo, err := deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet)) + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) if err != nil { t.Fatal(err) } @@ -173,11 +193,11 @@ func TestOne(t *testing.T) { // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if // needed. if remoteIPv6 == nil { - if _, err := dut.Exec(dockerutil.RunOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) } // Now try again, to make sure that it worked. - _, dutDeviceInfo, err = deviceByIP(dut, addressInSubnet(dutAddr, *testNet.Subnet)) + _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) if err != nil { t.Fatal(err) } @@ -188,16 +208,19 @@ func TestOne(t *testing.T) { } // Create the Docker container for the testbench. - testbench := dockerutil.MakeDocker(logger("testbench")) - testbench.Runtime = "" // The testbench always runs on Linux. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) tbb := path.Base(*testbenchBinary) containerTestbenchBinary := "/packetimpact/" + tbb runOpts = dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Extra: []string{"--sysctl", "net.ipv6.conf.all.disable_ipv6=0", "--rm", "-v", tmpDir + ":" + testOutputDir}, - Foreground: true, + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{mount.Mount{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, } testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) @@ -227,33 +250,42 @@ func TestOne(t *testing.T) { } }() - if err := testbench.Create(runOpts, snifferArgs...); err != nil { + conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...) + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil { t.Fatalf("unable to create container %s: %s", testbench.Name, err) } - defer testbench.CleanUp() + defer testbench.CleanUp(ctx) // Add ctrlNet as eth1 and testNet as eth2. - if err := addNetworks(testbench, testbenchAddr, []*dockerutil.DockerNetwork{ctrlNet, testNet}); err != nil { + if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { t.Fatal(err) } - if err := testbench.Start(); err != nil { + if err := testbench.Start(ctx); err != nil { t.Fatalf("unable to start container %s: %s", testbench.Name, err) } // Kill so that it will flush output. - defer testbench.Exec(dockerutil.RunOpts{}, "killall", snifferArgs[0]) + defer func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }() - if _, err := testbench.WaitForOutput(snifferRegex, 60*time.Second); err != nil { + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) } // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it - // will issue a RST. To prevent this IPtables can be used to filter out all + // will issue an RST. To prevent this IPtables can be used to filter out all // incoming packets. The raw socket that packetimpact tests use will still see // everything. - if _, err := testbench.Exec(dockerutil.RunOpts{}, "iptables", "-A", "INPUT", "-i", testNetDev, "-j", "DROP"); err != nil { - t.Fatalf("unable to Exec iptables on container %s: %s", testbench.Name, err) + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } } // FIXME(b/156449515): Some piece of the system has a race. The old @@ -273,23 +305,41 @@ func TestOne(t *testing.T) { "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), "--remote_ipv6", remoteIPv6.String(), "--remote_mac", remoteMAC.String(), + "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID), "--device", testNetDev, "--dut_type", *dutPlatform, ) - _, err = testbench.Exec(dockerutil.RunOpts{}, testArgs...) - if !*expectFailure && err != nil { - t.Fatal("test failed:", err) - } - if *expectFailure && err == nil { - t.Fatal("test failure expected but the test succeeded, enable the test and mark the corresponding bug as fixed") + testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if (err != nil) != *expectFailure { + var dutLogs string + if logs, err := dut.Logs(ctx); err != nil { + dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } else { + dutLogs = logs + } + + t.Errorf(`test error: %v, expect failure: %t + +====== Begin of DUT Logs ====== + +%s + +====== End of DUT Logs ====== + +====== Begin of Testbench Logs ====== + +%s + +====== End of Testbench Logs ======`, + err, *expectFailure, dutLogs, testbenchLogs) } } -func addNetworks(d *dockerutil.Docker, addr net.IP, networks []*dockerutil.DockerNetwork) error { +func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { for _, dn := range networks { ip := addressInSubnet(addr, *dn.Subnet) // Connect to the network with the specified IP address. - if err := dn.Connect(d, "--ip", ip.String()); err != nil { + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) } } @@ -307,9 +357,9 @@ func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { return net.IP(octets) } -// makeDockerNetwork makes a randomly-named network that will start with the +// createDockerNetwork makes a randomly-named network that will start with the // namePrefix. The network will be a random /24 subnet. -func createDockerNetwork(n *dockerutil.DockerNetwork) error { +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { randSource := rand.NewSource(time.Now().UnixNano()) r1 := rand.New(randSource) // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. @@ -318,12 +368,12 @@ func createDockerNetwork(n *dockerutil.DockerNetwork) error { IP: ip, Mask: ip.DefaultMask(), } - return n.Create() + return n.Create(ctx) } // deviceByIP finds a deviceInfo and device name from an IP address. -func deviceByIP(d *dockerutil.Docker, ip net.IP) (string, netdevs.DeviceInfo, error) { - out, err := d.Exec(dockerutil.RunOpts{}, "ip", "addr", "show") +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") if err != nil { return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) } diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index d19ec07d4..5a0ee1367 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -23,8 +23,8 @@ go_library( "//pkg/usermem", "//test/packetimpact/netdevs", "//test/packetimpact/proto:posix_server_go_proto", - "@com_github_google_go-cmp//cmp:go_default_library", - "@com_github_google_go-cmp//cmp/cmpopts:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", "@com_github_mohae_deepcopy//:go_default_library", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//keepalive:go_default_library", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 8b4a4d905..3af5f83fd 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -41,16 +41,19 @@ func portFromSockaddr(sa unix.Sockaddr) (uint16, error) { return 0, fmt.Errorf("sockaddr type %T does not contain port", sa) } -// pickPort makes a new socket and returns the socket FD and port. The domain should be AF_INET or AF_INET6. The caller must close the FD when done with +// pickPort makes a new socket and returns the socket FD and port. The domain +// should be AF_INET or AF_INET6. The caller must close the FD when done with // the port if there is no error. -func pickPort(domain, typ int) (int, uint16, error) { - fd, err := unix.Socket(domain, typ, 0) +func pickPort(domain, typ int) (fd int, port uint16, err error) { + fd, err = unix.Socket(domain, typ, 0) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("creating socket: %w", err) } defer func() { if err != nil { - err = multierr.Append(err, unix.Close(fd)) + if cerr := unix.Close(fd); cerr != nil { + err = multierr.Append(err, fmt.Errorf("failed to close socket %d: %w", fd, cerr)) + } } }() var sa unix.Sockaddr @@ -60,22 +63,22 @@ func pickPort(domain, typ int) (int, uint16, error) { copy(sa4.Addr[:], net.ParseIP(LocalIPv4).To4()) sa = &sa4 case unix.AF_INET6: - var sa6 unix.SockaddrInet6 + sa6 := unix.SockaddrInet6{ZoneId: uint32(LocalInterfaceID)} copy(sa6.Addr[:], net.ParseIP(LocalIPv6).To16()) sa = &sa6 default: return -1, 0, fmt.Errorf("invalid domain %d, it should be one of unix.AF_INET or unix.AF_INET6", domain) } if err = unix.Bind(fd, sa); err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("binding to %+v: %w", sa, err) } sa, err = unix.Getsockname(fd) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("Getsocketname(%d): %w", fd, err) } - port, err := portFromSockaddr(sa) + port, err = portFromSockaddr(sa) if err != nil { - return -1, 0, err + return -1, 0, fmt.Errorf("extracting port from socket address %+v: %w", sa, err) } return fd, port, nil } @@ -378,7 +381,7 @@ var _ layerState = (*udpState)(nil) func newUDPState(domain int, out, in UDP) (*udpState, error) { portPickerFD, localPort, err := pickPort(domain, unix.SOCK_DGRAM) if err != nil { - return nil, err + return nil, fmt.Errorf("picking port: %w", err) } s := udpState{ out: UDP{SrcPort: &localPort}, @@ -426,7 +429,6 @@ type Connection struct { layerStates []layerState injector Injector sniffer Sniffer - t *testing.T } // Returns the default incoming frame against which to match. If received is @@ -459,7 +461,9 @@ func (conn *Connection) match(override, received Layers) bool { } // Close frees associated resources held by the Connection. -func (conn *Connection) Close() { +func (conn *Connection) Close(t *testing.T) { + t.Helper() + errs := multierr.Combine(conn.sniffer.close(), conn.injector.close()) for _, s := range conn.layerStates { if err := s.close(); err != nil { @@ -467,7 +471,7 @@ func (conn *Connection) Close() { } } if errs != nil { - conn.t.Fatalf("unable to close %+v: %s", conn, errs) + t.Fatalf("unable to close %+v: %s", conn, errs) } } @@ -479,7 +483,9 @@ func (conn *Connection) Close() { // overriden first. As an example, valid values of overrideLayers for a TCP- // over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and // [Ethernet, IPv4, TCP]. -func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers { +func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers { + t.Helper() + var layersToSend Layers for i, s := range conn.layerStates { layer := s.outgoing() @@ -488,7 +494,7 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // end. if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 { if err := layer.merge(overrideLayers[j]); err != nil { - conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) + t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err) } } layersToSend = append(layersToSend, layer) @@ -502,21 +508,25 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L // This method is useful for sending out-of-band control messages such as // ICMP packets, where it would not make sense to update the transport layer's // state using the ICMP header. -func (conn *Connection) SendFrameStateless(frame Layers) { +func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) } // SendFrame sends a frame on the wire and updates the state of all layers. -func (conn *Connection) SendFrame(frame Layers) { +func (conn *Connection) SendFrame(t *testing.T, frame Layers) { + t.Helper() + outBytes, err := frame.ToBytes() if err != nil { - conn.t.Fatalf("can't build outgoing packet: %s", err) + t.Fatalf("can't build outgoing packet: %s", err) } - conn.injector.Send(outBytes) + conn.injector.Send(t, outBytes) // frame might have nil values where the caller wanted to use default values. // sentFrame will have no nil values in it because it comes from parsing the @@ -525,7 +535,7 @@ func (conn *Connection) SendFrame(frame Layers) { // Update the state of each layer based on what was sent. for i, s := range conn.layerStates { if err := s.sent(sentFrame[i]); err != nil { - conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) + t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err) } } } @@ -535,18 +545,22 @@ func (conn *Connection) SendFrame(frame Layers) { // // Types defined with Connection as the underlying type should expose // type-safe versions of this method. -func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) { - conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...)) +func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) { + t.Helper() + + conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...)) } // recvFrame gets the next successfully parsed frame (of type Layers) within the // timeout provided. If no parsable frame arrives before the timeout, it returns // nil. -func (conn *Connection) recvFrame(timeout time.Duration) Layers { +func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers { + t.Helper() + if timeout <= 0 { return nil } - b := conn.sniffer.Recv(timeout) + b := conn.sniffer.Recv(t, timeout) if b == nil { return nil } @@ -566,32 +580,36 @@ func (e *layersError) Error() string { // Expect expects a frame with the final layerStates layer matching the // provided Layer within the timeout specified. If it doesn't arrive in time, // an error is returned. -func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) { +func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) { + t.Helper() + // Make a frame that will ignore all but the final layer. layers := make([]Layer, len(conn.layerStates)) layers[len(layers)-1] = layer - gotFrame, err := conn.ExpectFrame(layers, timeout) + gotFrame, err := conn.ExpectFrame(t, layers, timeout) if err != nil { return nil, err } if len(conn.layerStates)-1 < len(gotFrame) { return gotFrame[len(conn.layerStates)-1], nil } - conn.t.Fatal("the received frame should be at least as long as the expected layers") + t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame) panic("unreachable") } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If one arrives in time, the Layers is returned without an // error. If it doesn't arrive in time, it returns nil and error is non-nil. -func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) { +func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) { + t.Helper() + deadline := time.Now().Add(timeout) var errs error for { var gotLayers Layers if timeout = time.Until(deadline); timeout > 0 { - gotLayers = conn.recvFrame(timeout) + gotLayers = conn.recvFrame(t, timeout) } if gotLayers == nil { if errs == nil { @@ -602,7 +620,7 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { if err := s.received(gotLayers[i]); err != nil { - conn.t.Fatal(err) + t.Fatalf("failed to update test connection's layer states based on received frame: %s", err) } } return gotLayers, nil @@ -613,8 +631,10 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *Connection) Drain() { - conn.sniffer.Drain() +func (conn *Connection) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection. @@ -622,6 +642,8 @@ type TCPIPv4 Connection // NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults. func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -647,57 +669,58 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 { layerStates: []layerState{etherState, ipv4State, tcpState}, injector: injector, sniffer: sniffer, - t: t, } } // Connect performs a TCP 3-way handshake. The input Connection should have a // final TCP Layer. -func (conn *TCPIPv4) Connect() { - conn.t.Helper() +func (conn *TCPIPv4) Connect(t *testing.T) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ConnectWithOptions performs a TCP 3-way handshake with given TCP options. // The input Connection should have a final TCP Layer. -func (conn *TCPIPv4) ConnectWithOptions(options []byte) { - conn.t.Helper() +func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { + t.Helper() // Send the SYN. - conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { - conn.t.Fatalf("didn't get synack during handshake: %s", err) + t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // ExpectNextData attempts to receive the next incoming segment for the @@ -706,9 +729,11 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio // It differs from ExpectData() in that here we are only interested in the next // received segment, while ExpectData() can receive multiple segments for the // connection until there is a match with given layers or a timeout. -func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { +func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + // Receive the first incoming TCP segment for this connection. - got, err := conn.ExpectData(&TCP{}, nil, timeout) + got, err := conn.ExpectData(t, &TCP{}, nil, timeout) if err != nil { return nil, err } @@ -717,7 +742,7 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur expected[len(expected)-1] = tcp if payload != nil { expected = append(expected, payload) - tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length())) + tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length())) } if !(*Connection)(conn).match(expected, got) { return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got) @@ -727,71 +752,91 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur // Send a packet with reasonable defaults. Potentially override the TCP layer in // the connection with the provided layer and add additionLayers. -func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&tcp}, additionalLayers...) +func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...) } // Close frees associated resources held by the TCPIPv4 connection. -func (conn *TCPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *TCPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Expect expects a frame with the TCP layer matching the provided TCP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) { - layer, err := (*Connection)(conn).Expect(&tcp, timeout) +func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &tcp, timeout) if layer == nil { return nil, err } gotTCP, ok := layer.(*TCP) if !ok { - conn.t.Fatalf("expected %s to be TCP", layer) + t.Fatalf("expected %s to be TCP", layer) } return gotTCP, err } -func (conn *TCPIPv4) tcpState() *tcpState { +func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState { + t.Helper() + state, ok := conn.layerStates[2].(*tcpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2]) } return state } -func (conn *TCPIPv4) ipv4State() *ipv4State { +func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // RemoteSeqNum returns the next expected sequence number from the DUT. -func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value { - return conn.tcpState().remoteSeqNum +func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).remoteSeqNum } // LocalSeqNum returns the next sequence number to send from the testbench. -func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value { - return conn.tcpState().localSeqNum +func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value { + t.Helper() + + return conn.tcpState(t).localSeqNum } // SynAck returns the SynAck that was part of the handshake. -func (conn *TCPIPv4) SynAck() *TCP { - return conn.tcpState().synAck +func (conn *TCPIPv4) SynAck(t *testing.T) *TCP { + t.Helper() + + return conn.tcpState(t).synAck } // LocalAddr gets the local socket address of this connection. -func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *TCPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *TCPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) } // IPv6Conn maintains the state for all the layers in a IPv6 connection. @@ -799,6 +844,8 @@ type IPv6Conn Connection // NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults. func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make EtherState: %s", err) @@ -821,25 +868,30 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn { layerStates: []layerState{etherState, ipv6State}, injector: injector, sniffer: sniffer, - t: t, } } // Send sends a frame with ipv6 overriding the IPv6 layer defaults and // additionalLayers added after it. -func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...) +func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...) } // Close to clean up any resources held. -func (conn *IPv6Conn) Close() { - (*Connection)(conn).Close() +func (conn *IPv6Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // ExpectFrame expects a frame that matches the provided Layers within the // timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) { - return (*Connection)(conn).ExpectFrame(frame, timeout) +func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(conn).ExpectFrame(t, frame, timeout) } // UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection. @@ -847,6 +899,8 @@ type UDPIPv4 Connection // NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults. func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { + t.Helper() + etherState, err := newEtherState(Ether{}, Ether{}) if err != nil { t.Fatalf("can't make etherState: %s", err) @@ -872,79 +926,280 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 { layerStates: []layerState{etherState, ipv4State, udpState}, injector: injector, sniffer: sniffer, - t: t, } } -func (conn *UDPIPv4) udpState() *udpState { +func (conn *UDPIPv4) udpState(t *testing.T) *udpState { + t.Helper() + state, ok := conn.layerStates[2].(*udpState) if !ok { - conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) } return state } -func (conn *UDPIPv4) ipv4State() *ipv4State { +func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State { + t.Helper() + state, ok := conn.layerStates[1].(*ipv4State) if !ok { - conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) + t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1]) } return state } // LocalAddr gets the local socket address of this connection. -func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 { - sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)} - copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr) +func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { + t.Helper() + + sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)} + copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr) return sa } // Send sends a packet with reasonable defaults, potentially overriding the UDP // layer and adding additionLayers. -func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&udp}, additionalLayers...) +func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) } // SendIP sends a packet with reasonable defaults, potentially overriding the // UDP and IPv4 headers and adding additionLayers. -func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) { - (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...) +func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) } // Expect expects a frame with the UDP layer matching the provided UDP within // the timeout specified. If it doesn't arrive in time, an error is returned. -func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) { - conn.t.Helper() - layer, err := (*Connection)(conn).Expect(&udp, timeout) - if layer == nil { +func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { return nil, err } gotUDP, ok := layer.(*UDP) if !ok { - conn.t.Fatalf("expected %s to be UDP", layer) + t.Fatalf("expected %s to be UDP", layer) } - return gotUDP, err + return gotUDP, nil } // ExpectData is a convenient method that expects a Layer and the Layer after // it. If it doens't arrive in time, it returns nil. -func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) { - conn.t.Helper() +func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + expected := make([]Layer, len(conn.layerStates)) expected[len(expected)-1] = &udp if payload.length() != 0 { expected = append(expected, &payload) } - return (*Connection)(conn).ExpectFrame(expected, timeout) + return (*Connection)(conn).ExpectFrame(t, expected, timeout) } // Close frees associated resources held by the UDPIPv4 connection. -func (conn *UDPIPv4) Close() { - (*Connection)(conn).Close() +func (conn *UDPIPv4) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) +} + +// Drain drains the sniffer's receive buffer by receiving packets until there's +// nothing else to receive. +func (conn *UDPIPv4) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection. +type UDPIPv6 Connection + +// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults. +func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make IPv6State: %s", err) + } + udpState, err := newUDPState(unix.AF_INET6, outgoingUDP, incomingUDP) + if err != nil { + t.Fatalf("can't make udpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + return UDPIPv6{ + layerStates: []layerState{etherState, ipv6State, udpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *UDPIPv6) udpState(t *testing.T) *udpState { + t.Helper() + + state, ok := conn.layerStates[2].(*udpState) + if !ok { + t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2]) + } + return state +} + +func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State { + t.Helper() + + state, ok := conn.layerStates[1].(*ipv6State) + if !ok { + t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1]) + } + return state +} + +// LocalAddr gets the local socket address of this connection. +func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 { + t.Helper() + + sa := &unix.SockaddrInet6{ + Port: int(*conn.udpState(t).out.SrcPort), + // Local address is in perspective to the remote host, so it's scoped to the + // ID of the remote interface. + ZoneId: uint32(RemoteInterfaceID), + } + copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr) + return sa +} + +// Send sends a packet with reasonable defaults, potentially overriding the UDP +// layer and adding additionLayers. +func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...) +} + +// SendIPv6 sends a packet with reasonable defaults, potentially overriding the +// UDP and IPv6 headers and adding additionLayers. +func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...) +} + +// Expect expects a frame with the UDP layer matching the provided UDP within +// the timeout specified. If it doesn't arrive in time, an error is returned. +func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) { + t.Helper() + + layer, err := (*Connection)(conn).Expect(t, &udp, timeout) + if err != nil { + return nil, err + } + gotUDP, ok := layer.(*UDP) + if !ok { + t.Fatalf("expected %s to be UDP", layer) + } + return gotUDP, nil +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = &udp + if payload.length() != 0 { + expected = append(expected, &payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the UDPIPv6 connection. +func (conn *UDPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } // Drain drains the sniffer's receive buffer by receiving packets until there's // nothing else to receive. -func (conn *UDPIPv4) Drain() { - conn.sniffer.Drain() +func (conn *UDPIPv6) Drain(t *testing.T) { + t.Helper() + + conn.sniffer.Drain(t) +} + +// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection. +type TCPIPv6 Connection + +// NewTCPIPv6 creates a new TCPIPv6 connection with reasonable defaults. +func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 { + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make etherState: %s", err) + } + ipv6State, err := newIPv6State(IPv6{}, IPv6{}) + if err != nil { + t.Fatalf("can't make ipv6State: %s", err) + } + tcpState, err := newTCPState(unix.AF_INET6, outgoingTCP, incomingTCP) + if err != nil { + t.Fatalf("can't make tcpState: %s", err) + } + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return TCPIPv6{ + layerStates: []layerState{etherState, ipv6State, tcpState}, + injector: injector, + sniffer: sniffer, + } +} + +func (conn *TCPIPv6) SrcPort() uint16 { + state := conn.layerStates[2].(*tcpState) + return *state.out.SrcPort +} + +// ExpectData is a convenient method that expects a Layer and the Layer after +// it. If it doens't arrive in time, it returns nil. +func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) { + t.Helper() + + expected := make([]Layer, len(conn.layerStates)) + expected[len(expected)-1] = tcp + if payload != nil { + expected = append(expected, payload) + } + return (*Connection)(conn).ExpectFrame(t, expected, timeout) +} + +// Close frees associated resources held by the TCPIPv6 connection. +func (conn *TCPIPv6) Close(t *testing.T) { + t.Helper() + + (*Connection)(conn).Close(t) } diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 2a2afecb5..73c532e75 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -31,13 +31,14 @@ import ( // DUT communicates with the DUT to force it to make POSIX calls. type DUT struct { - t *testing.T conn *grpc.ClientConn posixServer POSIXClient } // NewDUT creates a new connection with the DUT over gRPC. func NewDUT(t *testing.T) DUT { + t.Helper() + flag.Parse() if err := genPseudoFlags(); err != nil { t.Fatal("generating psuedo flags:", err) @@ -50,7 +51,6 @@ func NewDUT(t *testing.T) DUT { } posixServer := NewPOSIXClient(conn) return DUT{ - t: t, conn: conn, posixServer: posixServer, } @@ -61,8 +61,9 @@ func (dut *DUT) TearDown() { dut.conn.Close() } -func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { - dut.t.Helper() +func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr { + t.Helper() + switch s := sa.(type) { case *unix.SockaddrInet4: return &pb.Sockaddr{ @@ -87,12 +88,13 @@ func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr { }, } } - dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + t.Fatalf("can't parse Sockaddr struct: %+v", sa) return nil } -func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr { + t.Helper() + switch s := sa.Sockaddr.(type) { case *pb.Sockaddr_In: ret := unix.SockaddrInet4{ @@ -106,31 +108,34 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr { ZoneId: s.In6.GetScopeId(), } copy(ret.Addr[:], s.In6.GetAddr()) + return &ret } - dut.t.Fatalf("can't parse Sockaddr: %+v", sa) + t.Fatalf("can't parse Sockaddr proto: %#v", sa) return nil } // CreateBoundSocket makes a new socket on the DUT, with type typ and protocol // proto, and bound to the IP address addr. Returns the new file descriptor and // the port that was selected on the DUT. -func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) { - dut.t.Helper() +func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) { + t.Helper() + var fd int32 if addr.To4() != nil { - fd = dut.Socket(unix.AF_INET, typ, proto) + fd = dut.Socket(t, unix.AF_INET, typ, proto) sa := unix.SockaddrInet4{} copy(sa.Addr[:], addr.To4()) - dut.Bind(fd, &sa) + dut.Bind(t, fd, &sa) } else if addr.To16() != nil { - fd = dut.Socket(unix.AF_INET6, typ, proto) + fd = dut.Socket(t, unix.AF_INET6, typ, proto) sa := unix.SockaddrInet6{} copy(sa.Addr[:], addr.To16()) - dut.Bind(fd, &sa) + sa.ZoneId = uint32(RemoteInterfaceID) + dut.Bind(t, fd, &sa) } else { - dut.t.Fatalf("unknown ip addr type for remoteIP") + t.Fatalf("invalid IP address: %s", addr) } - sa := dut.GetSockName(fd) + sa := dut.GetSockName(t, fd) var port int switch s := sa.(type) { case *unix.SockaddrInet4: @@ -138,15 +143,17 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) case *unix.SockaddrInet6: port = s.Port default: - dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa) + t.Fatalf("unknown sockaddr type from getsockname: %T", sa) } return fd, uint16(port) } // CreateListener makes a new TCP connection. If it fails, the test ends. -func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { - fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4)) - dut.Listen(fd, backlog) +func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) { + t.Helper() + + fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4)) + dut.Listen(t, fd, backlog) return fd, remotePort } @@ -156,53 +163,57 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) { // Accept calls accept on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // AcceptWithErrno. -func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - fd, sa, err := dut.AcceptWithErrno(ctx, sockfd) + fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd) if fd < 0 { - dut.t.Fatalf("failed to accept: %s", err) + t.Fatalf("failed to accept: %s", err) } return fd, sa } // AcceptWithErrno calls accept on the DUT. -func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.AcceptRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.Accept(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Accept: %s", err) + t.Fatalf("failed to call Accept: %s", err) } - return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } // Bind calls bind on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is // needed, use BindWithErrno. -func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.BindWithErrno(ctx, fd, sa) + ret, err := dut.BindWithErrno(ctx, t, fd, sa) if ret != 0 { - dut.t.Fatalf("failed to bind socket: %s", err) + t.Fatalf("failed to bind socket: %s", err) } } // BindWithErrno calls bind on the DUT. -func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.BindRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Bind(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -210,25 +221,27 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) ( // Close calls close on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // CloseWithErrno. -func (dut *DUT) Close(fd int32) { - dut.t.Helper() +func (dut *DUT) Close(t *testing.T, fd int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.CloseWithErrno(ctx, fd) + ret, err := dut.CloseWithErrno(ctx, t, fd) if ret != 0 { - dut.t.Fatalf("failed to close: %s", err) + t.Fatalf("failed to close: %s", err) } } // CloseWithErrno calls close on the DUT. -func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) { + t.Helper() + req := pb.CloseRequest{ Fd: fd, } resp, err := dut.posixServer.Close(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Close: %s", err) + t.Fatalf("failed to call Close: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -236,28 +249,30 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) { // Connect calls connect on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use ConnectWithErrno. -func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) { - dut.t.Helper() +func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ConnectWithErrno(ctx, fd, sa) + ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) // Ignore 'operation in progress' error that can be returned when the socket // is non-blocking. if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { - dut.t.Fatalf("failed to connect socket: %s", err) + t.Fatalf("failed to connect socket: %s", err) } } // ConnectWithErrno calls bind on the DUT. -func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.ConnectRequest{ Sockfd: fd, - Addr: dut.sockaddrToProto(sa), + Addr: dut.sockaddrToProto(t, sa), } resp, err := dut.posixServer.Connect(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Connect: %s", err) + t.Fatalf("failed to call Connect: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -265,20 +280,22 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr // Fcntl calls fcntl on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use FcntlWithErrno. -func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 { - dut.t.Helper() +func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg) + ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg) if ret == -1 { - dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) + t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err) } return ret } // FcntlWithErrno calls fcntl on the DUT. -func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) { + t.Helper() + req := pb.FcntlRequest{ Fd: fd, Cmd: cmd, @@ -286,7 +303,7 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, } resp, err := dut.posixServer.Fcntl(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Fcntl: %s", err) + t.Fatalf("failed to call Fcntl: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -294,32 +311,35 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, // GetSockName calls getsockname on the DUT and causes a fatal test failure if // it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockNameWithErrno. -func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr { - dut.t.Helper() +func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd) + ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd) if ret != 0 { - dut.t.Fatalf("failed to getsockname: %s", err) + t.Fatalf("failed to getsockname: %s", err) } return sa } // GetSockNameWithErrno calls getsockname on the DUT. -func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) { - dut.t.Helper() +func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) { + t.Helper() + req := pb.GetSockNameRequest{ Sockfd: sockfd, } resp, err := dut.posixServer.GetSockName(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Bind: %s", err) + t.Fatalf("failed to call Bind: %s", err) } - return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) } -func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { - dut.t.Helper() +func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { + t.Helper() + req := pb.GetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -329,11 +349,11 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i } resp, err := dut.posixServer.GetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call GetSockOpt: %s", err) + t.Fatalf("failed to call GetSockOpt: %s", err) } optval := resp.GetOptval() if optval == nil { - dut.t.Fatalf("GetSockOpt response does not contain a value") + t.Fatalf("GetSockOpt response does not contain a value") } return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) } @@ -343,13 +363,14 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i // needed, use GetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific GetSockOptXxx function. -func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { - dut.t.Helper() +func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen) + ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen) if ret != 0 { - dut.t.Fatalf("failed to GetSockOpt: %s", err) + t.Fatalf("failed to GetSockOpt: %s", err) } return optval } @@ -357,12 +378,13 @@ func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte { // GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific GetSockOptXxxWithErrno function. -func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) +func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES) bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval) + t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val) } return ret, bytesval.Bytesval, errno } @@ -370,24 +392,26 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, // GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use GetSockOptIntWithErrno. -func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 { - dut.t.Helper() +func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname) + ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptInt: %s", err) + t.Fatalf("failed to GetSockOptInt: %s", err) } return intval } // GetSockOptIntWithErrno calls getsockopt with an integer optval. -func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) +func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT) intval, ok := optval.Val.(*pb.SockOptVal_Intval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval) + t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val) } return ret, intval.Intval, errno } @@ -395,24 +419,26 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna // GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use GetSockOptTimevalWithErrno. -func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval { - dut.t.Helper() +func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname) + ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname) if ret != 0 { - dut.t.Fatalf("failed to GetSockOptTimeval: %s", err) + t.Fatalf("failed to GetSockOptTimeval: %s", err) } return timeval } // GetSockOptTimevalWithErrno calls getsockopt and returns a timeval. -func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) { - dut.t.Helper() - ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) +func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) { + t.Helper() + + ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME) tv, ok := optval.Val.(*pb.SockOptVal_Timeval) if !ok { - dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval) + t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val) } timeval := unix.Timeval{ Sec: tv.Timeval.Seconds, @@ -424,26 +450,28 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o // Listen calls listen on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // ListenWithErrno. -func (dut *DUT) Listen(sockfd, backlog int32) { - dut.t.Helper() +func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.ListenWithErrno(ctx, sockfd, backlog) + ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog) if ret != 0 { - dut.t.Fatalf("failed to listen: %s", err) + t.Fatalf("failed to listen: %s", err) } } // ListenWithErrno calls listen on the DUT. -func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) { + t.Helper() + req := pb.ListenRequest{ Sockfd: sockfd, Backlog: backlog, } resp, err := dut.posixServer.Listen(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Listen: %s", err) + t.Fatalf("failed to call Listen: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -451,20 +479,22 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int // Send calls send on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendWithErrno. -func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 { - dut.t.Helper() +func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags) + ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags) if ret == -1 { - dut.t.Fatalf("failed to send: %s", err) + t.Fatalf("failed to send: %s", err) } return ret } // SendWithErrno calls send on the DUT. -func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) { + t.Helper() + req := pb.SendRequest{ Sockfd: sockfd, Buf: buf, @@ -472,7 +502,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla } resp, err := dut.posixServer.Send(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Send: %s", err) + t.Fatalf("failed to call Send: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -480,48 +510,52 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla // SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // SendToWithErrno. -func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { - dut.t.Helper() +func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr) + ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr) if ret == -1 { - dut.t.Fatalf("failed to sendto: %s", err) + t.Fatalf("failed to sendto: %s", err) } return ret } // SendToWithErrno calls sendto on the DUT. -func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { - dut.t.Helper() +func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) { + t.Helper() + req := pb.SendToRequest{ Sockfd: sockfd, Buf: buf, Flags: flags, - DestAddr: dut.sockaddrToProto(destAddr), + DestAddr: dut.sockaddrToProto(t, destAddr), } resp, err := dut.posixServer.SendTo(ctx, &req) if err != nil { - dut.t.Fatalf("faled to call SendTo: %s", err) + t.Fatalf("faled to call SendTo: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } // SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking // is true, otherwise it will clear the flag. -func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) { - dut.t.Helper() - flags := dut.Fcntl(fd, unix.F_GETFL, 0) +func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { + t.Helper() + + flags := dut.Fcntl(t, fd, unix.F_GETFL, 0) if nonblocking { flags |= unix.O_NONBLOCK } else { flags &= ^unix.O_NONBLOCK } - dut.Fcntl(fd, unix.F_SETFL, flags) + dut.Fcntl(t, fd, unix.F_SETFL, flags) } -func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { - dut.t.Helper() +func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) { + t.Helper() + req := pb.SetSockOptRequest{ Sockfd: sockfd, Level: level, @@ -530,7 +564,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op } resp, err := dut.posixServer.SetSockOpt(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call SetSockOpt: %s", err) + t.Fatalf("failed to call SetSockOpt: %s", err) } return resp.GetRet(), syscall.Errno(resp.GetErrno_()) } @@ -540,81 +574,89 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op // needed, use SetSockOptWithErrno. Because endianess and the width of values // might differ between the testbench and DUT architectures, prefer to use a // more specific SetSockOptXxx function. -func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) { - dut.t.Helper() +func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOpt: %s", err) + t.Fatalf("failed to SetSockOpt: %s", err) } } // SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the // width of values might differ between the testbench and DUT architectures, // prefer to use a more specific SetSockOptXxxWithErrno function. -func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) +func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}}) } // SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the int optval or error handling // is needed, use SetSockOptIntWithErrno. -func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) { - dut.t.Helper() +func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval) + ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptInt: %s", err) + t.Fatalf("failed to SetSockOptInt: %s", err) } } // SetSockOptIntWithErrno calls setsockopt with an integer optval. -func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) { - dut.t.Helper() - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) +func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) { + t.Helper() + + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}}) } // SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure // if it doesn't succeed. If more control over the timeout or error handling is // needed, use SetSockOptTimevalWithErrno. -func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv) + ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv) if ret != 0 { - dut.t.Fatalf("failed to SetSockOptTimeval: %s", err) + t.Fatalf("failed to SetSockOptTimeval: %s", err) } } // SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to // bytes. -func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { - dut.t.Helper() +func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) { + t.Helper() + timeval := pb.Timeval{ Seconds: int64(tv.Sec), Microseconds: int64(tv.Usec), } - return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) + return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}}) } // Socket calls socket on the DUT and returns the file descriptor. If socket // fails on the DUT, the test ends. -func (dut *DUT) Socket(domain, typ, proto int32) int32 { - dut.t.Helper() - fd, err := dut.SocketWithErrno(domain, typ, proto) +func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 { + t.Helper() + + fd, err := dut.SocketWithErrno(t, domain, typ, proto) if fd < 0 { - dut.t.Fatalf("failed to create socket: %s", err) + t.Fatalf("failed to create socket: %s", err) } return fd } // SocketWithErrno calls socket on the DUT and returns the fd and errno. -func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { - dut.t.Helper() +func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) { + t.Helper() + req := pb.SocketRequest{ Domain: domain, Type: typ, @@ -623,7 +665,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { ctx := context.Background() resp, err := dut.posixServer.Socket(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Socket: %s", err) + t.Fatalf("failed to call Socket: %s", err) } return resp.GetFd(), syscall.Errno(resp.GetErrno_()) } @@ -631,20 +673,22 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) { // Recv calls recv on the DUT and causes a fatal test failure if it doesn't // succeed. If more control over the timeout or error handling is needed, use // RecvWithErrno. -func (dut *DUT) Recv(sockfd, len, flags int32) []byte { - dut.t.Helper() +func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags) + ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags) if ret == -1 { - dut.t.Fatalf("failed to recv: %s", err) + t.Fatalf("failed to recv: %s", err) } return buf } // RecvWithErrno calls recv on the DUT. -func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) { - dut.t.Helper() +func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) { + t.Helper() + req := pb.RecvRequest{ Sockfd: sockfd, Len: len, @@ -652,7 +696,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (in } resp, err := dut.posixServer.Recv(ctx, &req) if err != nil { - dut.t.Fatalf("failed to call Recv: %s", err) + t.Fatalf("failed to call Recv: %s", err) } return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 560c4111b..24aa46cce 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -15,6 +15,7 @@ package testbench import ( + "encoding/binary" "encoding/hex" "fmt" "reflect" @@ -470,17 +471,11 @@ func (l *IPv6) ToBytes() ([]byte, error) { if l.NextHeader != nil { fields.NextHeader = *l.NextHeader } else { - switch n := l.next().(type) { - case *TCP: - fields.NextHeader = uint8(header.TCPProtocolNumber) - case *UDP: - fields.NextHeader = uint8(header.UDPProtocolNumber) - case *ICMPv6: - fields.NextHeader = uint8(header.ICMPv6ProtocolNumber) - default: - // TODO(b/150301488): Support more protocols as needed. - return nil, fmt.Errorf("ToBytes can't deduce the IPv6 header's next protocol: %#v", n) + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err } + fields.NextHeader = nh } if l.HopLimit != nil { fields.HopLimit = *l.HopLimit @@ -495,6 +490,27 @@ func (l *IPv6) ToBytes() ([]byte, error) { return h, nil } +// nextIPv6PayloadParser finds the corresponding parser for nextHeader. +func nextIPv6PayloadParser(nextHeader uint8) layerParser { + switch tcpip.TransportProtocolNumber(nextHeader) { + case header.TCPProtocolNumber: + return parseTCP + case header.UDPProtocolNumber: + return parseUDP + case header.ICMPv6ProtocolNumber: + return parseICMPv6 + } + switch header.IPv6ExtensionHeaderIdentifier(nextHeader) { + case header.IPv6HopByHopOptionsExtHdrIdentifier: + return parseIPv6HopByHopOptionsExtHdr + case header.IPv6DestinationOptionsExtHdrIdentifier: + return parseIPv6DestinationOptionsExtHdr + case header.IPv6FragmentExtHdrIdentifier: + return parseIPv6FragmentExtHdr + } + return parsePayload +} + // parseIPv6 parses the bytes assuming that they start with an ipv6 header and // continues parsing further encapsulations. func parseIPv6(b []byte) (Layer, layerParser) { @@ -509,18 +525,7 @@ func parseIPv6(b []byte) (Layer, layerParser) { SrcAddr: Address(h.SourceAddress()), DstAddr: Address(h.DestinationAddress()), } - var nextParser layerParser - switch h.TransportProtocol() { - case header.TCPProtocolNumber: - nextParser = parseTCP - case header.UDPProtocolNumber: - nextParser = parseUDP - case header.ICMPv6ProtocolNumber: - nextParser = parseICMPv6 - default: - // Assume that the rest is a payload. - nextParser = parsePayload - } + nextParser := nextIPv6PayloadParser(h.NextHeader()) return &ipv6, nextParser } @@ -538,13 +543,241 @@ func (l *IPv6) merge(other Layer) error { return mergeLayer(l, other) } +// IPv6HopByHopOptionsExtHdr can construct and match an IPv6HopByHopOptions +// Extension Header. +type IPv6HopByHopOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// IPv6DestinationOptionsExtHdr can construct and match an IPv6DestinationOptions +// Extension Header. +type IPv6DestinationOptionsExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + Options []byte +} + +// IPv6FragmentExtHdr can construct and match an IPv6 Fragment Extension Header. +type IPv6FragmentExtHdr struct { + LayerBase + NextHeader *header.IPv6ExtensionHeaderIdentifier + FragmentOffset *uint16 + MoreFragments *bool + Identification *uint32 +} + +// nextHeaderByLayer finds the correct next header protocol value for layer l. +func nextHeaderByLayer(l Layer) (uint8, error) { + if l == nil { + return uint8(header.IPv6NoNextHeaderIdentifier), nil + } + switch l.(type) { + case *TCP: + return uint8(header.TCPProtocolNumber), nil + case *UDP: + return uint8(header.UDPProtocolNumber), nil + case *ICMPv6: + return uint8(header.ICMPv6ProtocolNumber), nil + case *Payload: + return uint8(header.IPv6NoNextHeaderIdentifier), nil + case *IPv6HopByHopOptionsExtHdr: + return uint8(header.IPv6HopByHopOptionsExtHdrIdentifier), nil + case *IPv6DestinationOptionsExtHdr: + return uint8(header.IPv6DestinationOptionsExtHdrIdentifier), nil + case *IPv6FragmentExtHdr: + return uint8(header.IPv6FragmentExtHdrIdentifier), nil + default: + // TODO(b/161005083): Support more protocols as needed. + return 0, fmt.Errorf("failed to deduce the IPv6 header's next protocol: %T", l) + } +} + +// ipv6OptionsExtHdrToBytes serializes an options extension header into bytes. +func ipv6OptionsExtHdrToBytes(nextHeader *header.IPv6ExtensionHeaderIdentifier, nextLayer Layer, options []byte) ([]byte, error) { + length := len(options) + 2 + if length%8 != 0 { + return nil, fmt.Errorf("IPv6 extension headers must be a multiple of 8 octets long, but the length given: %d, options: %s", length, hex.Dump(options)) + } + bytes := make([]byte, length) + if nextHeader != nil { + bytes[0] = byte(*nextHeader) + } else { + nh, err := nextHeaderByLayer(nextLayer) + if err != nil { + return nil, err + } + bytes[0] = nh + } + // ExtHdrLen field is the length of the extension header + // in 8-octet unit, ignoring the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + bytes[1] = uint8((length - 8) / 8) + copy(bytes[2:], options) + return bytes, nil +} + +// IPv6ExtHdrIdent is a helper routine that allocates a new +// header.IPv6ExtensionHeaderIdentifier value to store v and returns a pointer +// to it. +func IPv6ExtHdrIdent(id header.IPv6ExtensionHeaderIdentifier) *header.IPv6ExtensionHeaderIdentifier { + return &id +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6HopByHopOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6DestinationOptionsExtHdr) ToBytes() ([]byte, error) { + return ipv6OptionsExtHdrToBytes(l.NextHeader, l.next(), l.Options) +} + +// ToBytes implements Layer.ToBytes. +func (l *IPv6FragmentExtHdr) ToBytes() ([]byte, error) { + var offset, mflag uint16 + var ident uint32 + bytes := make([]byte, header.IPv6FragmentExtHdrLength) + if l.NextHeader != nil { + bytes[0] = byte(*l.NextHeader) + } else { + nh, err := nextHeaderByLayer(l.next()) + if err != nil { + return nil, err + } + bytes[0] = nh + } + bytes[1] = 0 // reserved + if l.MoreFragments != nil && *l.MoreFragments { + mflag = 1 + } + if l.FragmentOffset != nil { + offset = *l.FragmentOffset + } + if l.Identification != nil { + ident = *l.Identification + } + offsetAndMflag := offset<<3 | mflag + binary.BigEndian.PutUint16(bytes[2:], offsetAndMflag) + binary.BigEndian.PutUint32(bytes[4:], ident) + + return bytes, nil +} + +// parseIPv6ExtHdr parses an IPv6 extension header and returns the NextHeader +// field, the rest of the payload and a parser function for the corresponding +// next extension header. +func parseIPv6ExtHdr(b []byte) (header.IPv6ExtensionHeaderIdentifier, []byte, layerParser) { + nextHeader := b[0] + // For HopByHop and Destination options extension headers, + // This field is the length of the extension header in + // 8-octet units, not including the first 8 octets. + // https://tools.ietf.org/html/rfc2460#section-4.3 + // https://tools.ietf.org/html/rfc2460#section-4.6 + length := b[1]*8 + 8 + data := b[2:length] + nextParser := nextIPv6PayloadParser(nextHeader) + return header.IPv6ExtensionHeaderIdentifier(nextHeader), data, nextParser +} + +// parseIPv6HopByHopOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 HopByHop Options Extension Header. +func parseIPv6HopByHopOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6HopByHopOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +// parseIPv6DestinationOptionsExtHdr parses the bytes assuming that they start +// with an IPv6 Destination Options Extension Header. +func parseIPv6DestinationOptionsExtHdr(b []byte) (Layer, layerParser) { + nextHeader, options, nextParser := parseIPv6ExtHdr(b) + return &IPv6DestinationOptionsExtHdr{NextHeader: &nextHeader, Options: options}, nextParser +} + +// Bool is a helper routine that allocates a new +// bool value to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// parseIPv6FragmentExtHdr parses the bytes assuming that they start +// with an IPv6 Fragment Extension Header. +func parseIPv6FragmentExtHdr(b []byte) (Layer, layerParser) { + nextHeader := b[0] + var extHdr header.IPv6FragmentExtHdr + copy(extHdr[:], b[2:]) + return &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(nextHeader)), + FragmentOffset: Uint16(extHdr.FragmentOffset()), + MoreFragments: Bool(extHdr.More()), + Identification: Uint32(extHdr.ID()), + }, nextIPv6PayloadParser(nextHeader) +} + +func (l *IPv6HopByHopOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6HopByHopOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6HopByHopOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6HopByHopOptionsExtHdr) String() string { + return stringLayer(l) +} + +func (l *IPv6DestinationOptionsExtHdr) length() int { + return len(l.Options) + 2 +} + +func (l *IPv6DestinationOptionsExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6DestinationOptionsExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6DestinationOptionsExtHdr) String() string { + return stringLayer(l) +} + +func (*IPv6FragmentExtHdr) length() int { + return header.IPv6FragmentExtHdrLength +} + +func (l *IPv6FragmentExtHdr) match(other Layer) bool { + return equalLayer(l, other) +} + +// merge overrides the values in l with the values from other but only in fields +// where the value is not nil. +func (l *IPv6FragmentExtHdr) merge(other Layer) error { + return mergeLayer(l, other) +} + +func (l *IPv6FragmentExtHdr) String() string { + return stringLayer(l) +} + // ICMPv6 can construct and match an ICMPv6 encapsulation. type ICMPv6 struct { LayerBase - Type *header.ICMPv6Type - Code *byte - Checksum *uint16 - NDPPayload []byte + Type *header.ICMPv6Type + Code *byte + Checksum *uint16 + Payload []byte } func (l *ICMPv6) String() string { @@ -555,7 +788,7 @@ func (l *ICMPv6) String() string { // ToBytes implements Layer.ToBytes. func (l *ICMPv6) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv6HeaderSize+len(l.NDPPayload)) + b := make([]byte, header.ICMPv6HeaderSize+len(l.Payload)) h := header.ICMPv6(b) if l.Type != nil { h.SetType(*l.Type) @@ -563,12 +796,23 @@ func (l *ICMPv6) ToBytes() ([]byte, error) { if l.Code != nil { h.SetCode(*l.Code) } - copy(h.NDPPayload(), l.NDPPayload) + copy(h.NDPPayload(), l.Payload) if l.Checksum != nil { h.SetChecksum(*l.Checksum) } else { - ipv6 := l.Prev().(*IPv6) - h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, buffer.VectorisedView{})) + // It is possible that the ICMPv6 header does not follow the IPv6 header + // immediately, there could be one or more extension headers in between. + // We need to search forward to find the IPv6 header. + for prev := l.Prev(); prev != nil; prev = prev.Prev() { + if ipv6, ok := prev.(*IPv6); ok { + payload, err := payload(l) + if err != nil { + return nil, err + } + h.SetChecksum(header.ICMPv6Checksum(h, *ipv6.SrcAddr, *ipv6.DstAddr, payload)) + break + } + } } return h, nil } @@ -589,10 +833,10 @@ func Byte(v byte) *byte { func parseICMPv6(b []byte) (Layer, layerParser) { h := header.ICMPv6(b) icmpv6 := ICMPv6{ - Type: ICMPv6Type(h.Type()), - Code: Byte(h.Code()), - Checksum: Uint16(h.Checksum()), - NDPPayload: h.NDPPayload(), + Type: ICMPv6Type(h.Type()), + Code: Byte(h.Code()), + Checksum: Uint16(h.Checksum()), + Payload: h.NDPPayload(), } return &icmpv6, nil } @@ -602,7 +846,7 @@ func (l *ICMPv6) match(other Layer) bool { } func (l *ICMPv6) length() int { - return header.ICMPv6HeaderSize + len(l.NDPPayload) + return header.ICMPv6HeaderSize + len(l.Payload) } // merge overrides the values in l with the values from other but only in fields @@ -768,12 +1012,14 @@ func payload(l Layer) (buffer.VectorisedView, error) { func layerChecksum(l Layer, protoNumber tcpip.TransportProtocolNumber) (uint16, error) { totalLength := uint16(totalLength(l)) var xsum uint16 - switch s := l.Prev().(type) { + switch p := l.Prev().(type) { case *IPv4: - xsum = header.PseudoHeaderChecksum(protoNumber, *s.SrcAddr, *s.DstAddr, totalLength) + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) + case *IPv6: + xsum = header.PseudoHeaderChecksum(protoNumber, *p.SrcAddr, *p.DstAddr, totalLength) default: - // TODO(b/150301488): Support more protocols, like IPv6. - return 0, fmt.Errorf("can't get src and dst addr from previous layer: %#v", s) + // TODO(b/161246171): Support more protocols. + return 0, fmt.Errorf("checksum for protocol %d is not supported when previous layer is %T", protoNumber, p) } payloadBytes, err := payload(l) if err != nil { diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index c7f00e70d..a2a763034 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -505,3 +505,224 @@ func TestTCPOptions(t *testing.T) { }) } } + +func TestIPv6ExtHdrOptions(t *testing.T) { + for _, tt := range []struct { + description string + wantBytes []byte + wantLayers Layers + }{ + { + description: "IPv6/HopByHop", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/HopByHop/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3b, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/HopByHop/Destination/ICMPv6", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x18, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x3c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Destination Options + 0x3a, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // ICMPv6 Param Problem + 0x04, 0x00, 0x5f, 0x98, 0x00, 0x00, 0x00, 0x06, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6DestinationOptionsExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &ICMPv6{ + Type: ICMPv6Type(header.ICMPv6ParamProblem), + Code: Byte(0), + Checksum: Uint16(0x5f98), + Payload: []byte{0x00, 0x00, 0x00, 0x06}, + }, + }, + }, + { + description: "IPv6/HopByHop/Fragment", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // HopByHop Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x20, 0x00, 0x00, 0x00, 0x2a, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6HopByHopOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(false), + Identification: Uint32(42), + }, + &Payload{ + Bytes: nil, + }, + }, + }, + { + description: "IPv6/DestOpt/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x1b, 0x3c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Destination Options + 0x2c, 0x00, 0x05, 0x02, 0x00, 0x00, 0x01, 0x00, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6DestinationOptionsExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6FragmentExtHdrIdentifier), + Options: []byte{0x05, 0x02, 0x00, 0x00, 0x01, 0x00}, + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + { + description: "IPv6/Fragment/Payload", + wantBytes: []byte{ + // IPv6 Header + 0x60, 0x00, 0x00, 0x00, 0x00, 0x13, 0x2c, 0x40, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef, + // Fragment ExtHdr + 0x3b, 0x00, 0x03, 0x21, 0x00, 0x00, 0x00, 0x2a, + // Sample Data + 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x44, 0x61, 0x74, 0x61, + }, + wantLayers: []Layer{ + &IPv6{ + SrcAddr: Address(tcpip.Address(net.ParseIP("::1"))), + DstAddr: Address(tcpip.Address(net.ParseIP("fe80::dead:beef"))), + }, + &IPv6FragmentExtHdr{ + NextHeader: IPv6ExtHdrIdent(header.IPv6NoNextHeaderIdentifier), + FragmentOffset: Uint16(100), + MoreFragments: Bool(true), + Identification: Uint32(42), + }, + &Payload{ + Bytes: []byte("Sample Data"), + }, + }, + }, + } { + t.Run(tt.description, func(t *testing.T) { + layers := parse(parseIPv6, tt.wantBytes) + if !layers.match(tt.wantLayers) { + t.Fatalf("match failed with diff: %s", layers.diff(tt.wantLayers)) + } + // Make sure we can generate correct next header values and checksums + for _, layer := range layers { + switch layer := layer.(type) { + case *IPv6HopByHopOptionsExtHdr: + layer.NextHeader = nil + case *IPv6DestinationOptionsExtHdr: + layer.NextHeader = nil + case *IPv6FragmentExtHdr: + layer.NextHeader = nil + case *ICMPv6: + layer.Checksum = nil + } + } + gotBytes, err := layers.ToBytes() + if err != nil { + t.Fatalf("ToBytes() failed on %s: %s", &layers, err) + } + if !bytes.Equal(tt.wantBytes, gotBytes) { + t.Fatalf("mismatching bytes, gotBytes: %x, wantBytes: %x", gotBytes, tt.wantBytes) + } + }) + } +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 278229b7e..57e822725 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -28,7 +28,6 @@ import ( // Sniffer can sniff raw packets on the wire. type Sniffer struct { - t *testing.T fd int } @@ -40,6 +39,8 @@ func htons(x uint16) uint16 { // NewSniffer creates a Sniffer connected to *device. func NewSniffer(t *testing.T) (Sniffer, error) { + t.Helper() + snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL))) if err != nil { return Sniffer{}, err @@ -51,7 +52,6 @@ func NewSniffer(t *testing.T) (Sniffer, error) { t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err) } return Sniffer{ - t: t, fd: snifferFd, }, nil } @@ -61,7 +61,9 @@ func NewSniffer(t *testing.T) (Sniffer, error) { const maxReadSize int = 65536 // Recv tries to read one frame until the timeout is up. -func (s *Sniffer) Recv(timeout time.Duration) []byte { +func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte { + t.Helper() + deadline := time.Now().Add(timeout) for { timeout = deadline.Sub(time.Now()) @@ -75,7 +77,7 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { } if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { - s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) + t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err) } buf := make([]byte, maxReadSize) @@ -85,10 +87,10 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { continue } if err != nil { - s.t.Fatalf("can't read: %s", err) + t.Fatalf("can't read: %s", err) } if nread > maxReadSize { - s.t.Fatalf("received a truncated frame of %d bytes", nread) + t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize) } return buf[:nread] } @@ -96,14 +98,16 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte { // Drain drains the Sniffer's socket receive buffer by receiving until there's // nothing else to receive. -func (s *Sniffer) Drain() { - s.t.Helper() +func (s *Sniffer) Drain(t *testing.T) { + t.Helper() + flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0) if err != nil { - s.t.Fatalf("failed to get sniffer socket fd flags: %s", err) + t.Fatalf("failed to get sniffer socket fd flags: %s", err) } - if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil { - s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err) + nonBlockingFlags := flags | unix.O_NONBLOCK + if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil { + t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err) } for { buf := make([]byte, maxReadSize) @@ -113,7 +117,7 @@ func (s *Sniffer) Drain() { } } if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil { - s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err) + t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err) } } @@ -128,12 +132,13 @@ func (s *Sniffer) close() error { // Injector can inject raw frames. type Injector struct { - t *testing.T fd int } // NewInjector creates a new injector on *device. func NewInjector(t *testing.T) (Injector, error) { + t.Helper() + ifInfo, err := net.InterfaceByName(Device) if err != nil { return Injector{}, err @@ -156,15 +161,20 @@ func NewInjector(t *testing.T) (Injector, error) { return Injector{}, err } return Injector{ - t: t, fd: injectFd, }, nil } // Send a raw frame. -func (i *Injector) Send(b []byte) { - if _, err := unix.Write(i.fd, b); err != nil { - i.t.Fatalf("can't write: %s of len %d", err, len(b)) +func (i *Injector) Send(t *testing.T, b []byte) { + t.Helper() + + n, err := unix.Write(i.fd, b) + if err != nil { + t.Fatalf("can't write bytes of len %d: %s", len(b), err) + } + if n != len(b) { + t.Fatalf("got %d bytes written, want %d", n, len(b)) } } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index d64f32a5b..242464e3a 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -31,23 +31,37 @@ var ( DUTType = "" // Device is the local device on the test network. Device = "" + // LocalIPv4 is the local IPv4 address on the test network. LocalIPv4 = "" + // RemoteIPv4 is the DUT's IPv4 address on the test network. + RemoteIPv4 = "" + // IPv4PrefixLength is the network prefix length of the IPv4 test network. + IPv4PrefixLength = 0 + // LocalIPv6 is the local IPv6 address on the test network. LocalIPv6 = "" + // RemoteIPv6 is the DUT's IPv6 address on the test network. + RemoteIPv6 = "" + + // LocalInterfaceID is the ID of the local interface on the test network. + LocalInterfaceID uint32 + // RemoteInterfaceID is the ID of the remote interface on the test network. + // + // Not using uint32 because package flag does not support uint32. + RemoteInterfaceID uint64 + // LocalMAC is the local MAC address on the test network. LocalMAC = "" + // RemoteMAC is the DUT's MAC address on the test network. + RemoteMAC = "" + // POSIXServerIP is the POSIX server's IP address on the control network. POSIXServerIP = "" // POSIXServerPort is the UDP port the POSIX server is bound to on the // control network. POSIXServerPort = 40000 - // RemoteIPv4 is the DUT's IPv4 address on the test network. - RemoteIPv4 = "" - // RemoteIPv6 is the DUT's IPv6 address on the test network. - RemoteIPv6 = "" - // RemoteMAC is the DUT's MAC address on the test network. - RemoteMAC = "" + // RPCKeepalive is the gRPC keepalive. RPCKeepalive = 10 * time.Second // RPCTimeout is the gRPC timeout. @@ -68,6 +82,7 @@ func RegisterFlags(fs *flag.FlagSet) { fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") fs.StringVar(&Device, "device", Device, "local device for test packets") fs.StringVar(&DUTType, "dut_type", DUTType, "type of device under test") + fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") } // genPseudoFlags populates flag-like global config based on real flags. @@ -90,6 +105,13 @@ func genPseudoFlags() error { LocalMAC = deviceInfo.MAC.String() LocalIPv6 = deviceInfo.IPv6Addr.String() + LocalInterfaceID = deviceInfo.ID + + if deviceInfo.IPv4Net != nil { + IPv4PrefixLength, _ = deviceInfo.IPv4Net.Mask.Size() + } else { + IPv4PrefixLength, _ = net.ParseIP(LocalIPv4).DefaultMask().Size() + } return nil } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 09e91b832..0c2a05380 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -18,8 +18,6 @@ packetimpact_go_test( packetimpact_go_test( name = "ipv4_id_uniqueness", srcs = ["ipv4_id_uniqueness_test.go"], - # TODO(b/157506701) Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/abi/linux", "//pkg/tcpip/header", @@ -29,14 +27,37 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "udp_recv_multicast", - srcs = ["udp_recv_multicast_test.go"], + name = "udp_discard_mcast_source_addr", + srcs = ["udp_discard_mcast_source_addr_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_recv_mcast_bcast", + srcs = ["udp_recv_mcast_bcast_test.go"], # TODO(b/152813495): Fix netstack then remove the line below. expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "udp_any_addr_recv_unicast", + srcs = ["udp_any_addr_recv_unicast_test.go"], + deps = [ + "//pkg/tcpip", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -183,8 +204,6 @@ packetimpact_go_test( packetimpact_go_test( name = "tcp_queue_receive_in_syn_sent", srcs = ["tcp_queue_receive_in_syn_sent_test.go"], - # TODO(b/157658105): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", @@ -213,8 +232,8 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "tcp_splitseg_mss", - srcs = ["tcp_splitseg_mss_test.go"], + name = "tcp_network_unreachable", + srcs = ["tcp_network_unreachable_test.go"], deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", @@ -256,10 +275,38 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "ipv6_unknown_options_action", + srcs = ["ipv6_unknown_options_action_test.go"], + # TODO(b/159928940): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "ipv6_fragment_reassembly", + srcs = ["ipv6_fragment_reassembly_test.go"], + # TODO(b/160919104): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "udp_send_recv_dgram", srcs = ["udp_send_recv_dgram_test.go"], deps = [ "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 407565078..a61054c2c 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -39,34 +39,34 @@ func TestFinWait2Timeout(t *testing.T) { t.Run(tt.description, func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() + defer conn.Close(t) + conn.Connect(t) - acceptFd, _ := dut.Accept(listenFd) + acceptFd, _ := dut.Accept(t, listenFd) if tt.linger2 { tv := unix.Timeval{Sec: 1, Usec: 0} - dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) + dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv) } - dut.Close(acceptFd) + dut.Close(t, acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) time.Sleep(5 * time.Second) - conn.Drain() + conn.Drain(t) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) if tt.linger2 { - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) } } diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go index 4d1d9a7f5..2d59d552d 100644 --- a/test/packetimpact/tests/icmpv6_param_problem_test.go +++ b/test/packetimpact/tests/icmpv6_param_problem_test.go @@ -34,19 +34,19 @@ func TestICMPv6ParamProblemTest(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close() + defer conn.Close(t) ipv6 := testbench.IPv6{ // 254 is reserved and used for experimentation and testing. This should // cause an error. NextHeader: testbench.Uint8(254), } icmpv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - NDPPayload: []byte("hello world"), + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Payload: []byte("hello world"), } - toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6) - (*testbench.Connection)(&conn).SendFrame(toSend) + toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6) + (*testbench.Connection)(&conn).SendFrame(t, toSend) // Build the expected ICMPv6 payload, which includes an index to the // problematic byte and also the problematic packet as described in @@ -62,8 +62,8 @@ func TestICMPv6ParamProblemTest(t *testing.T) { binary.BigEndian.PutUint32(b, header.IPv6NextHeaderOffset) expectedPayload = append(b, expectedPayload...) expectedICMPv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), - NDPPayload: expectedPayload, + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Payload: expectedPayload, } paramProblem := testbench.Layers{ @@ -72,7 +72,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) { &expectedICMPv6, } timeout := time.Second - if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil { + if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil { t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err) } } diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go index 70f6df5e0..cf881418c 100644 --- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -31,8 +31,8 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { - layers, err := conn.ExpectData(expect, expectPayload, time.Second) +func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) { + layers, err := conn.ExpectData(t, expect, expectPayload, time.Second) if err != nil { return 0, fmt.Errorf("failed to receive TCP segment: %s", err) } @@ -69,17 +69,17 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - remoteFD, _ := dut.Accept(listenFD) - defer dut.Close(remoteFD) + conn.Connect(t) + remoteFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, remoteFD) - dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) // TODO(b/129291778) The following socket option clears the DF bit on // IP packets sent over the socket, and is currently not supported by @@ -87,30 +87,30 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { // socket option being not supported does not affect the operation of // this test. Once the socket option is supported, the following call // can be changed to simply assert success. - ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) + ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT) if ret == -1 && errno != unix.ENOTSUP { t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno) } samplePayload := &testbench.Payload{Bytes: tc.payload} - dut.Send(remoteFD, tc.payload, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, remoteFD, tc.payload, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err) } // Let the DUT estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(remoteFD, tc.payload, 0) - expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))} - originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + dut.Send(t, remoteFD, tc.payload, 0) + expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))} + originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive TCP segment: %s", err) } - retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload) + retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload) if err != nil { t.Fatalf("failed to receive retransmitted TCP segment: %s", err) } diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go new file mode 100644 index 000000000..b5f94ad4b --- /dev/null +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_fragment_reassembly_test + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +const ( + // The payload length for the first fragment we send. This number + // is a multiple of 8 near 750 (half of 1500). + firstPayloadLength = 752 + // The ID field for our outgoing fragments. + fragmentID = 1 + // A node must be able to accept a fragmented packet that, + // after reassembly, is as large as 1500 octets. + reassemblyCap = 1500 +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestIPv6FragmentReassembly(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + + firstPayloadToSend := make([]byte, firstPayloadLength) + for i := range firstPayloadToSend { + firstPayloadToSend[i] = 'A' + } + + secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize + secondPayloadToSend := firstPayloadToSend[:secondPayloadLength] + + icmpv6EchoPayload := make([]byte, 4) + binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0) + binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0) + icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...) + + lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) + icmpv6 := testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.Byte(0), + Payload: icmpv6EchoPayload, + } + icmpv6Bytes, err := icmpv6.ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + cksum := header.ICMPv6Checksum( + header.ICMPv6(icmpv6Bytes), + lIP, + rIP, + buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), + ) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), + Code: testbench.Byte(0), + Payload: icmpv6EchoPayload, + Checksum: &cksum, + }) + + icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) + + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), + MoreFragments: testbench.Bool(false), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.Payload{ + Bytes: secondPayloadToSend, + }) + + gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + FragmentOffset: testbench.Uint16(0), + MoreFragments: testbench.Bool(true), + }, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), + Code: testbench.Byte(0), + }, + }, time.Second) + if err != nil { + t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err) + } + + id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification + gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6: %s", err) + } + icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:] + receivedLen := len(icmpPayload) + wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen + wantFirstPayload := make([]byte, receivedLen) + for i := range wantFirstPayload { + wantFirstPayload[i] = 'A' + } + wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen] + if !bytes.Equal(icmpPayload, wantFirstPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(icmpPayload), + hex.Dump(wantFirstPayload)) + } + + gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)), + MoreFragments: testbench.Bool(false), + Identification: &id, + }, + &testbench.ICMPv6{}, + }, time.Second) + if err != nil { + t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err) + } + secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes() + if err != nil { + t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err) + } + if !bytes.Equal(secondPayload, wantSecondPayload) { + t.Fatalf("received unexpected payload, got: %s, want: %s", + hex.Dump(secondPayload), + hex.Dump(wantSecondPayload)) + } +} diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go new file mode 100644 index 000000000..d7d63cbd2 --- /dev/null +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -0,0 +1,187 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv6_unknown_options_action_test + +import ( + "encoding/binary" + "flag" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6HopByHopOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func mkDestinationOptionsExtHdr(optType byte) testbench.Layer { + return &testbench.IPv6DestinationOptionsExtHdr{ + Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00}, + } +} + +func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte { + return byte(action << 6) +} + +func TestIPv6UnknownOptionAction(t *testing.T) { + for _, tt := range []struct { + description string + mkExtHdr func(optType byte) testbench.Layer + action header.IPv6OptionUnknownAction + multicastDst bool + wantICMPv6 bool + }{ + { + description: "0b00/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/hbh", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/hbh/unicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/hbh/multicast", + mkExtHdr: mkHopByHopOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + { + description: "0b00/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionSkip, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b01/destination", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscard, + multicastDst: false, + wantICMPv6: false, + }, + { + description: "0b10/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b10/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMP, + multicastDst: true, + wantICMPv6: true, + }, + { + description: "0b11/destination/unicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: false, + wantICMPv6: true, + }, + { + description: "0b11/destination/multicast", + mkExtHdr: mkDestinationOptionsExtHdr, + action: header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest, + multicastDst: true, + wantICMPv6: false, + }, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + conn := (*testbench.Connection)(&ipv6Conn) + defer ipv6Conn.Close(t) + + outgoingOverride := testbench.Layers{} + if tt.multicastDst { + outgoingOverride = testbench.Layers{&testbench.IPv6{ + DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))), + }} + } + + outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action))) + conn.SendFrame(t, outgoing) + ipv6Sent := outgoing[1:] + invokingPacket, err := ipv6Sent.ToBytes() + if err != nil { + t.Fatalf("failed to serialize the outgoing packet: %s", err) + } + icmpv6Payload := make([]byte, 4) + // The pointer in the ICMPv6 parameter problem message should point to + // the option type of the unknown option. In our test case, it is the + // first option in the extension header whose option type is 2 bytes + // after the IPv6 header (after NextHeader and ExtHdrLen). + binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2) + icmpv6Payload = append(icmpv6Payload, invokingPacket...) + gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), + Code: testbench.Byte(2), + Payload: icmpv6Payload, + }, + }, time.Second) + if tt.wantICMPv6 && err != nil { + t.Fatalf("expected ICMPv6 Parameter Problem but got none: %s", err) + } + if !tt.wantICMPv6 && gotICMPv6 != nil { + t.Fatalf("expected no ICMPv6 Parameter Problem but got one: %s", gotICMPv6) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go index 6e7ff41d7..e6a96f214 100644 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go @@ -33,39 +33,39 @@ func init() { func TestCloseWaitAck(t *testing.T) { for _, tt := range []struct { description string - makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP seqNumOffset seqnum.Size expectAck bool }{ - {"OTW", GenerateOTWSeqSegment, 0, false}, - {"OTW", GenerateOTWSeqSegment, 1, true}, - {"OTW", GenerateOTWSeqSegment, 2, true}, - {"ACK", GenerateUnaccACKSegment, 0, false}, - {"ACK", GenerateUnaccACKSegment, 1, true}, - {"ACK", GenerateUnaccACKSegment, 2, true}, + {"OTW", generateOTWSeqSegment, 0, false}, + {"OTW", generateOTWSeqSegment, 1, true}, + {"OTW", generateOTWSeqSegment, 2, true}, + {"ACK", generateUnaccACKSegment, 0, false}, + {"ACK", generateUnaccACKSegment, 1, true}, + {"ACK", generateUnaccACKSegment, 2, true}, } { t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) // Send a FIN to DUT to intiate the active close - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) } windowSize := seqnum.Size(*gotTCP.WindowSize) // Send a segment with OTW Seq / unacc ACK and expect an ACK back - conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) - gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if tt.expectAck && err != nil { t.Fatalf("expected an ack but got none: %s", err) } @@ -74,35 +74,36 @@ func TestCloseWaitAck(t *testing.T) { } // Now let's verify DUT is indeed in CLOSE_WAIT - dut.Close(acceptFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + dut.Close(t, acceptFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { t.Fatalf("expected DUT to send a FIN: %s", err) } // Ack the FIN from DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // Send some extra data to DUT - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected DUT to send an RST: %s", err) } }) } } -// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the -// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK -// is expected from the receiver. -func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum().Add(windowSize) +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} } -// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated -// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is -// expected from the receiver. -func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum() +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} } diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go index fb8f48629..8feea4a82 100644 --- a/test/packetimpact/tests/tcp_cork_mss_test.go +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -32,53 +32,53 @@ func init() { func TestTCPCorkMSS(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) const mss = uint32(header.TCPDefaultMSS) options := make([]byte, header.TCPOptionMSSLength) header.EncodeMSSOption(mss, options) - conn.ConnectWithOptions(options) + conn.ConnectWithOptions(t, options) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) + dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1) // Let the dut application send 2 small segments to be held up and coalesced // until the application sends a larger segment to fill up to > MSS. sampleData := []byte("Sample Data") - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedData := sampleData expectedData = append(expectedData, sampleData...) largeData := make([]byte, mss+1) expectedData = append(expectedData, largeData...) - dut.Send(acceptFD, largeData, 0) + dut.Send(t, acceptFD, largeData, 0) // Expect the segments to be coalesced and sent and capped to MSS. expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the coalesced segment to be split and transmitted. expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Check for segments to *not* be held up because of TCP_CORK when // the current send window is less than MSS. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) - dut.Send(acceptFD, sampleData, 0) - dut.Send(acceptFD, sampleData, 0) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + dut.Send(t, acceptFD, sampleData, 0) + dut.Send(t, acceptFD, sampleData, 0) expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) } diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go index 652b530d0..22937d92f 100644 --- a/test/packetimpact/tests/tcp_handshake_window_size_test.go +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -33,14 +33,14 @@ func init() { func TestTCPHandshakeWindowSize(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Start handshake with zero window size. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK: %s", err) } // Update the advertised window size to a non-zero value with the ACK that @@ -48,10 +48,10 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // // Set the window size with MSB set and expect the dut to treat it as // an unsigned value. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) - acceptFd, _ := dut.Accept(listenFD) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFd) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} @@ -59,8 +59,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // Since we advertised a zero window followed by a non-zero window, // expect the dut to honor the recently advertised non-zero window // and actually send out the data instead of probing for zero window. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go new file mode 100644 index 000000000..900352fa1 --- /dev/null +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -0,0 +1,139 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_synsent_reset_test + +import ( + "context" + "flag" + "net" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTCPSynSentUnreachable verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + port := uint16(9001) + conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet4{Port: int(port)} + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv4) + if !ok { + t.Fatalf("expected %s to be IPv4", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4HostUnreachable)} + layers = append(layers, &icmpv4, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) { + t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err) + } +} + +// TestTCPSynSentUnreachable6 verifies that TCP connections fail immediately when +// an ICMP destination unreachable message is sent in response to the inital +// SYN. +func TestTCPSynSentUnreachable6(t *testing.T) { + // Create the DUT and connection. + dut := testbench.NewDUT(t) + defer dut.TearDown() + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6)) + conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort}) + defer conn.Close(t) + + // Bring the DUT to SYN-SENT state with a non-blocking connect. + ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout) + defer cancel() + sa := unix.SockaddrInet6{ + Port: int(conn.SrcPort()), + ZoneId: uint32(testbench.RemoteInterfaceID), + } + copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16()) + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err) + } + + // Get the SYN. + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + if err != nil { + t.Fatalf("expected SYN: %s", err) + } + + // Send a host unreachable message. + rawConn := (*testbench.Connection)(&conn) + layers := rawConn.CreateFrame(t, nil) + layers = layers[:len(layers)-1] + const ipLayer = 1 + const tcpLayer = ipLayer + 1 + ip, ok := tcpLayers[ipLayer].(*testbench.IPv6) + if !ok { + t.Fatalf("expected %s to be IPv6", tcpLayers[ipLayer]) + } + tcp, ok := tcpLayers[tcpLayer].(*testbench.TCP) + if !ok { + t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) + } + var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{ + Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable), + Code: testbench.Uint8(header.ICMPv6NetworkUnreachable), + // Per RFC 4443 3.1, the payload contains 4 zeroed bytes. + Payload: []byte{0, 0, 0, 0}, + } + layers = append(layers, &icmpv6, ip, tcp) + rawConn.SendFrameStateless(t, layers) + + if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) { + t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err) + } +} diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index b9b3e91d3..82b7a85ff 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -31,12 +31,12 @@ func init() { func TestTcpNoAcceptCloseReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - conn.Connect() - defer conn.Close() - dut.Close(listenFd) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + conn.Connect(t) + defer conn.Close(t) + dut.Close(t, listenFd) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) } } diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index ad8c74234..08f759f7c 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -63,25 +63,25 @@ func TestTCPOutsideTheWindow(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) - windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset - conn.Drain() + windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset + conn.Drain(t) // Ignore whatever incrementing that this out-of-order packet might cause // to the AckNum. - localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{ + localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{ Flags: testbench.Uint8(tt.tcpFlags), - SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))), + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), }, tt.payload...) timeout := 3 * time.Second - gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) if tt.expectACK && err != nil { t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) } diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go index 55db4ece6..37f3b56dd 100644 --- a/test/packetimpact/tests/tcp_paws_mechanism_test.go +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -32,15 +32,15 @@ func init() { func TestPAWSMechanism(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) options := make([]byte, header.TCPOptionTSLength) header.EncodeTSOption(currentTS(), 0, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) - synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } @@ -50,9 +50,9 @@ func TestPAWSMechanism(t *testing.T) { } tsecr := parsedSynOpts.TSVal header.EncodeTSOption(currentTS(), tsecr, options) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) sampleData := []byte("Sample Data") sentTSVal := currentTS() @@ -61,9 +61,9 @@ func TestPAWSMechanism(t *testing.T) { // every time we send one, it should not cause any flakiness because timestamps // only need to be non-decreasing. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK but got none: %s", err) } @@ -86,9 +86,9 @@ func TestPAWSMechanism(t *testing.T) { // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness // due to the exact same reasoning discussed above. time.Sleep(3 * time.Millisecond) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) } diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go index b640d8673..d9f3ea0f2 100644 --- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go +++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go @@ -35,53 +35,98 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } +// TestQueueReceiveInSynSent tests receive behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants where the receive is blocked and: +// (1) we complete handshake and send sample data. +// (2) we send a TCP RST. func TestQueueReceiveInSynSent(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Send DATA", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() - socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) - conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) - sampleData := []byte("Sample Data") + sampleData := []byte("Sample Data") - dut.SetNonBlocking(socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) { - t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) - } - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { - t.Fatalf("expected a SYN from DUT, but got none: %s", err) - } + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } - // Issue RECEIVE call in SYN-SENT, this should be queued for process until the connection - // is established. - dut.SetNonBlocking(socket, false) - var wg sync.WaitGroup - defer wg.Wait() - wg.Add(1) - go func() { - defer wg.Done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() - n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0) - if n == -1 { - t.Fatalf("failed to recv on DUT: %s", err) - } - if got := buff[:n]; !bytes.Equal(got, sampleData) { - t.Fatalf("received data don't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData)) - } - }() - - // The following sleep is used to prevent the connection from being established while the - // RPC is in flight. - time.Sleep(time.Second) - - // Bring the connection to Established. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) - if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { - t.Fatalf("expected an ACK from DUT, but got none: %s", err) - } + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking read. + dut.SetNonBlocking(t, socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() - // Send sample data to DUT. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + block.Done() + // Issue RECEIVE call in SYN-SENT, this should be queued for + // process until the connection is established. + n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n == -1 { + t.Errorf("failed to recv on DUT: %s", err) + } + if got := buff[:n]; !bytes.Equal(got, sampleData) { + t.Errorf("received data doesn't match, got:\n%s, want:\n%s", hex.Dump(got), hex.Dump(sampleData)) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint receive. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on Recv. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + + // Send sample payload and expect an ACK. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } } diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go index a5378a9dd..8742819ca 100644 --- a/test/packetimpact/tests/tcp_reordering_test.go +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -32,10 +32,10 @@ func init() { func TestReorderingWindow(t *testing.T) { dut := tb.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Enable SACK. opts := make([]byte, 40) @@ -49,17 +49,17 @@ func TestReorderingWindow(t *testing.T) { const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize optsOff += header.EncodeMSSOption(mss, opts[optsOff:]) - conn.ConnectWithOptions(opts[:optsOff]) + conn.ConnectWithOptions(t, opts[:optsOff]) - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) if tb.DUTType == "linux" { // Linux has changed its handling of reordering, force the old behavior. - dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) + dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) } - pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) + pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) if tb.DUTType == "netstack" { // netstack does not impliment TCP_MAXSEG correctly. Fake it // here. Netstack uses the max SACK size which is 32. The MSS @@ -69,13 +69,13 @@ func TestReorderingWindow(t *testing.T) { payload := make([]byte, pls) - seqNum1 := *conn.RemoteSeqNum() + seqNum1 := *conn.RemoteSeqNum(t) const numPkts = 10 // Send some packets, checking that we receive each. for i, sn := 0, seqNum1; i < numPkts; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -86,7 +86,7 @@ func TestReorderingWindow(t *testing.T) { } } - seqNum2 := *conn.RemoteSeqNum() + seqNum2 := *conn.RemoteSeqNum(t) // SACK packets #2-4. sackBlock := make([]byte, 40) @@ -97,13 +97,13 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(len(payload))), seqNum1.Add(seqnum.Size(4 * len(payload))), }}, sackBlock[sbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // ACK first packet. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))}) // Check for retransmit. - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second) if err != nil { t.Error("Expect for retransmit:", err) } @@ -123,14 +123,14 @@ func TestReorderingWindow(t *testing.T) { seqNum1.Add(seqnum.Size(4 * len(payload))), }}, dsackBlock[dsbOff:]) - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]}) // Send half of the original window of packets, checking that we // received each. for i, sn := 0, seqNum2; i < numPkts/2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -144,8 +144,8 @@ func TestReorderingWindow(t *testing.T) { if tb.DUTType == "netstack" { // The window should now be halved, so we should receive any // more, even if we send them. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } return @@ -153,9 +153,9 @@ func TestReorderingWindow(t *testing.T) { // Linux reduces the window by three. Check that we can receive the rest. for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ { - dut.Send(acceptFd, payload, 0) + dut.Send(t, acceptFd, payload, 0) - gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) + gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second) sn.UpdateForward(seqnum.Size(len(payload))) if err != nil { t.Errorf("Expect #%d: %s", i+1, err) @@ -167,8 +167,8 @@ func TestReorderingWindow(t *testing.T) { } // The window should now be full. - dut.Send(acceptFd, payload, 0) - if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { + dut.Send(t, acceptFd, payload, 0) + if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil { t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got) } } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index 6940eb7fb..072014ff8 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -33,41 +33,41 @@ func init() { func TestRetransmits(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) startRTO := time.Second current := startRTO first := time.Now() - dut.Send(acceptFd, sampleData, 0) - seq := testbench.Uint32(uint32(*conn.RemoteSeqNum())) - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Expect retransmits of the same segment. for i := 0; i < 5; i++ { start := time.Now() - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil { t.Fatalf("expected payload was not received: %s loop %d", err, i) } if i == 0 { diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go index 90ab85419..f91b06ba1 100644 --- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -61,23 +61,23 @@ func TestSendWindowSizesPiggyback(t *testing.T) { t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -86,18 +86,18 @@ func TestSendWindowSizesPiggyback(t *testing.T) { if tt.enqueue { // Enqueue a segment for the dut to transmit. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) } // Send ACK for the previous segment along with data for the dut to // receive and ACK back. Sending this ACK would make room for the dut // to transmit any enqueued segment. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) // Expect the dut to piggyback the ACK for received data along with // the segment enqueued for transmit. expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2} - if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } }) diff --git a/test/packetimpact/tests/tcp_splitseg_mss_test.go b/test/packetimpact/tests/tcp_splitseg_mss_test.go deleted file mode 100644 index 9350d0988..000000000 --- a/test/packetimpact/tests/tcp_splitseg_mss_test.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_splitseg_mss_test - -import ( - "flag" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.RegisterFlags(flag.CommandLine) -} - -// TestTCPSplitSegMSS lets the dut try to send segments larger than MSS. -// It tests if the transmitted segments are capped at MSS and are split. -func TestTCPSplitSegMSS(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) - conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - - const mss = uint32(header.TCPDefaultMSS) - options := make([]byte, header.TCPOptionMSSLength) - header.EncodeMSSOption(mss, options) - conn.ConnectWithOptions(options) - - acceptFD, _ := dut.Accept(listenFD) - defer dut.Close(acceptFD) - - // Let the dut send a segment larger than MSS. - largeData := make([]byte, mss+1) - for i := 0; i < 2; i++ { - dut.Send(acceptFD, largeData, 0) - if i == 0 { - // On Linux, the initial segment goes out beyond MSS and the segment - // split occurs on retransmission. Call ExpectData to wait to - // receive the split segment. - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: largeData[:mss]}, time.Second); err != nil { - t.Fatalf("expected payload was not received: %s", err) - } - } else { - if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: largeData[:mss]}, time.Second); err != nil { - t.Fatalf("expected payload was not received: %s", err) - } - } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: largeData[mss:]}, time.Second); err != nil { - t.Fatalf("expected payload was not received: %s", err) - } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - } -} diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go index 7d5deab01..57d034dd1 100644 --- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -32,21 +32,21 @@ func init() { func TestTCPSynRcvdReset(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Expect dut connection to have transitioned to SYN-RCVD state. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST %s", err) } } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index 6898a2239..eac8eb19d 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -31,17 +31,19 @@ func init() { // dutSynSentState sets up the dut connection in SYN-SENT state. func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { + t.Helper() + dut := tb.NewDUT(t) - clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) + clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4)) port := uint16(9001) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port}) sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4()) // Bring the dut to SYN-SENT state with a non-blocking connect. - dut.Connect(clientFD, &sa) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + dut.Connect(t, clientFD, &sa) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { t.Fatalf("expected SYN\n") } @@ -51,13 +53,13 @@ func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) { // TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition. func TestTCPSynSentReset(t *testing.T) { dut, conn, _, _ := dutSynSentState(t) - defer conn.Close() + defer conn.Close(t) defer dut.TearDown() - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } @@ -67,22 +69,22 @@ func TestTCPSynSentReset(t *testing.T) { func TestTCPSynSentRcvdReset(t *testing.T) { dut, c, remotePort, clientPort := dutSynSentState(t) defer dut.TearDown() - defer c.Close() + defer c.Close(t) conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) // Initiate new SYN connection with the same port pair // (simultaneous open case), expect the dut connection to move to // SYN-RCVD state - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s\n", err) } - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go index 87e45d765..551dc78e7 100644 --- a/test/packetimpact/tests/tcp_user_timeout_test.go +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -16,7 +16,6 @@ package tcp_user_timeout_test import ( "flag" - "fmt" "testing" "time" @@ -29,22 +28,20 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { +func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { sampleData := make([]byte, 100) for i := range sampleData { sampleData[i] = uint8(i) } - conn.Drain() - dut.Send(fd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { - return fmt.Errorf("expected data but got none: %w", err) + conn.Drain(t) + dut.Send(t, fd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + t.Fatalf("expected data but got none: %w", err) } - return nil } -func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error { - dut.Close(fd) - return nil +func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) { + dut.Close(t, fd) } func TestTCPUserTimeout(t *testing.T) { @@ -59,7 +56,7 @@ func TestTCPUserTimeout(t *testing.T) { } { for _, ttf := range []struct { description string - f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error + f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32) }{ {"AfterPayload", sendPayload}, {"AfterFIN", sendFIN}, @@ -68,31 +65,29 @@ func TestTCPUserTimeout(t *testing.T) { // Create a socket, listen, TCP handshake, and accept. dut := testbench.NewDUT(t) defer dut.TearDown() - listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFD) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() - conn.Connect() - acceptFD, _ := dut.Accept(listenFD) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) if tt.userTimeout != 0 { - dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) + dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds())) } - if err := ttf.f(&conn, &dut, acceptFD); err != nil { - t.Fatal(err) - } + ttf.f(t, &conn, &dut, acceptFD) time.Sleep(tt.sendDelay) - conn.Drain() - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Drain(t) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) // If TCP_USER_TIMEOUT was set and the above delay was longer than the // TCP_USER_TIMEOUT then the DUT should send a RST in response to the // testbench's packet. expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout expectTimeout := 5 * time.Second - got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) if expectRST && err != nil { t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) } diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index e78d04756..5b001fbec 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -31,43 +31,43 @@ func init() { func TestWindowShrink(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - dut.Send(acceptFd, sampleData, 0) - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // We close our receiving window here - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - dut.Send(acceptFd, []byte("Sample Data"), 0) + dut.Send(t, acceptFd, []byte("Sample Data"), 0) // Note: There is another kind of zero-window probing which Windows uses (by sending one // new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change // the following lines. - expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1 - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { + expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1 + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index 5ab193181..da93267d6 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeRetransmit(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,15 +63,15 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // of the recorded first zero probe transmission duration. // // Advertize zero receive window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) startProbeDuration := time.Second current := startProbeDuration first := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect the dut to keep the connection alive as long as the remote is // acknowledging the zero-window probes. for i := 0; i < 5; i++ { @@ -79,7 +79,7 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { // Expect zero-window probe with a timeout which is a function of the typical // first retransmission time. The retransmission times is supposed to // exponentially increase. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil { t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i) } if i == 0 { @@ -88,18 +88,17 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { continue } // Check if the probes came at exponentially increasing intervals. - if p := time.Since(start); p < current-startProbeDuration { - t.Fatalf("zero probe came sooner interval %d probe %d\n", p, i) + if got, want := time.Since(start), current-startProbeDuration; got < want { + t.Errorf("got zero probe %d after %s, want >= %s", i, got, want) } // Acknowledge the zero-window probes from the dut. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) current *= 2 } // Advertize non-zero window. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench. - TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go index 649fd5699..44cac42f8 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -33,29 +33,29 @@ func init() { func TestZeroWindowProbe(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} start := time.Now() // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } sendTime := time.Now().Sub(start) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -63,24 +63,24 @@ func TestZeroWindowProbe(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) // Expected ack number of the ACK for the probe. - ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum())) + ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) // Expect there are no zero-window probes sent until there is data to be sent out // from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil { t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err) } start = time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Expect the probe to be sent after some time. Compare against the previous @@ -94,9 +94,9 @@ func TestZeroWindowProbe(t *testing.T) { // and sends out the sample payload after the send window opens. // // Advertize non-zero window to the dut and ack the zero window probe. - conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } @@ -104,9 +104,9 @@ func TestZeroWindowProbe(t *testing.T) { // Check if the dut responds as we do for a similar probe sent to it. // Basically with sequence number to one byte behind the unacknowledged // sequence number. - p := testbench.Uint32(uint32(*conn.LocalSeqNum())) - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + p := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { t.Fatalf("expected a packet with ack number: %d: %s", p, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go index 3c467b14f..09a1c653f 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -33,27 +33,27 @@ func init() { func TestZeroWindowProbeUserTimeout(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(listenFd) + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) - conn.Connect() - acceptFd, _ := dut.Accept(listenFd) - defer dut.Close(acceptFd) + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} // Send and receive sample data to the dut. - dut.Send(acceptFd, sampleData, 0) - if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil { + dut.Send(t, acceptFd, sampleData, 0) + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -61,15 +61,15 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. - probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1)) + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) start := time.Now() // Ask the dut to send out data. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Expect zero-window probe from the dut. - if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil { t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err) } // Record the duration for first probe, the dut sends the zero window probe after @@ -80,19 +80,19 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // when the dut is sending zero-window probes. // // Reduce the retransmit timeout. - dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) + dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) // Advertize zero window again. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Ask the dut to send out data that would trigger zero window probe retransmissions. - dut.Send(acceptFd, sampleData, 0) + dut.Send(t, acceptFd, sampleData, 0) // Wait for the connection to timeout after multiple zero-window probe retransmissions. time.Sleep(8 * startProbeDuration) // Expect the connection to have timed out and closed which would cause the dut // to reply with a RST to the ACK we send. - conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/udp_recv_multicast_test.go b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go index 77a9bfa1d..17f32ef65 100644 --- a/test/packetimpact/tests/udp_recv_multicast_test.go +++ b/test/packetimpact/tests/udp_any_addr_recv_unicast_test.go @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package udp_recv_multicast_test +package udp_any_addr_recv_unicast_test import ( "flag" "net" "testing" + "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/test/packetimpact/testbench" @@ -28,13 +29,23 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestUDPRecvMulticast(t *testing.T) { +func TestAnyRecvUnicastUDP(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, boundFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() - conn.SendIP(testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP("224.0.0.1").To4()))}, testbench.UDP{}) - dut.Recv(boundFD, 100, 0) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(net.ParseIP(testbench.RemoteIPv4).To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } } diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go new file mode 100644 index 000000000..d30177e64 --- /dev/null +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_discard_mcast_source_addr_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +var oneSecond = unix.Timeval{Sec: 1, Usec: 0} + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv4allsys, + net.IPv4allrouter, + net.IPv4(224, 0, 1, 42), + net.IPv4(232, 1, 2, 3), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIP( + t, + testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, + testbench.UDP{}, + ) + + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6)) + defer dut.Close(t, remoteFD) + dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond) + conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, mcastAddr := range []net.IP{ + net.IPv6interfacelocalallnodes, + net.IPv6linklocalallnodes, + net.IPv6linklocalallrouters, + net.ParseIP("fe01::42"), + net.ParseIP("fe02::4242"), + } { + t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { + conn.SendIPv6( + t, + testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, + testbench.UDP{}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index b754918f6..b47ddb6c3 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -72,7 +72,7 @@ func (e icmpError) ToICMPv4() *testbench.ICMPv4 { type errorDetection struct { name string useValidConn bool - f func(context.Context, testData) error + f func(context.Context, *testing.T, testData) } type testData struct { @@ -95,12 +95,14 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { } // sendICMPError sends an ICMP error message in response to a UDP datagram. -func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error { - layers := (*testbench.Connection)(conn).CreateFrame(nil) +func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) { + t.Helper() + + layers := (*testbench.Connection)(conn).CreateFrame(t, nil) layers = layers[:len(layers)-1] ip, ok := udp.Prev().(*testbench.IPv4) if !ok { - return fmt.Errorf("expected %s to be IPv4", udp.Prev()) + t.Fatalf("expected %s to be IPv4", udp.Prev()) } if icmpErr == timeToLiveExceeded { *ip.TTL = 1 @@ -114,84 +116,82 @@ func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UD // resulting in a mal-formed packet. layers = append(layers, icmpErr.ToICMPv4(), ip, udp) - (*testbench.Connection)(conn).SendFrameStateless(layers) - return nil + (*testbench.Connection)(conn).SendFrameStateless(t, layers) } // testRecv tests observing the ICMP error through the recv syscall. A packet // is sent to the DUT, and if wantErrno is non-zero, then the first recv should // fail and the second should succeed. Otherwise if wantErrno is zero then the // first recv should succeed immediately. -func testRecv(ctx context.Context, d testData) error { +func testRecv(ctx context.Context, t *testing.T, d testData) { + t.Helper() + // Check that receiving on the clean socket works. - d.conn.Send(testbench.UDP{DstPort: &d.cleanPort}) - d.dut.Recv(d.cleanFD, 100, 0) + d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort}) + d.dut.Recv(t, d.cleanFD, 100, 0) - d.conn.Send(testbench.UDP{}) + d.conn.Send(t, testbench.UDP{}) if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0) + ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) if ret != -1 { - return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.Recv(d.remoteFD, 100, 0) - return nil + d.dut.Recv(t, d.remoteFD, 100, 0) } // testSendTo tests observing the ICMP error through the send syscall. If // wantErrno is non-zero, the first send should fail and a subsequent send // should suceed; while if wantErrno is zero then the first send should just // succeed. -func testSendTo(ctx context.Context, d testData) error { +func testSendTo(ctx context.Context, t *testing.T, d testData) { // Check that sending on the clean socket works. - d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err) + d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err) } if d.wantErrno != syscall.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr()) + ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) if ret != -1 { - return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) + t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno) } if err != d.wantErrno { - return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) + t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno) } } - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } -func testSockOpt(_ context.Context, d testData) error { +func testSockOpt(_ context.Context, t *testing.T, d testData) { // Check that there's no pending error on the clean socket. - if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { - return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno) } - if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { - return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) + if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) } // Check that after clearing socket error, sending doesn't fail. - d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr()) - if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil { - return fmt.Errorf("did not receive UDP packet as expected: %s", err) + d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) + if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil { + t.Fatalf("did not receive UDP packet as expected: %s", err) } - return nil } // TestUDPICMPErrorPropagation tests that ICMP error messages in response to @@ -227,31 +227,29 @@ func TestUDPICMPErrorPropagation(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) errDetectConn := &conn if errDetect.useValidConn { @@ -260,14 +258,12 @@ func TestUDPICMPErrorPropagation(t *testing.T) { // interactions between it and the the DUT should be independent of // the ICMP error at least at the port level. connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer connClean.Close() + defer connClean.Close(t) errDetectConn = &connClean } - if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil { - t.Fatal(err) - } + errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}) }) } } @@ -285,24 +281,24 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(remoteFD) + remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, remoteFD) // Create a second, clean socket on the DUT to ensure that the ICMP // error messages only affect the sockets they are intended for. - cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(cleanFD) + cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4zero) + defer dut.Close(t, cleanFD) conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + defer conn.Close(t) if connect { - dut.Connect(remoteFD, conn.LocalAddr()) - dut.Connect(cleanFD, conn.LocalAddr()) + dut.Connect(t, remoteFD, conn.LocalAddr(t)) + dut.Connect(t, cleanFD, conn.LocalAddr(t)) } - dut.SendTo(remoteFD, nil, 0, conn.LocalAddr()) - udp, err := conn.Expect(testbench.UDP{}, time.Second) + dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t)) + udp, err := conn.Expect(t, testbench.UDP{}, time.Second) if err != nil { t.Fatalf("did not receive message from DUT: %s", err) } @@ -316,7 +312,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0) + ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0) if ret != -1 { t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno) return @@ -330,7 +326,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 { t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err) } }() @@ -341,7 +337,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 { + if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 { t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err) } }() @@ -352,12 +348,10 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { // alternative is available. time.Sleep(2 * time.Second) - if err := sendICMPError(&conn, icmpErr, udp); err != nil { - t.Fatal(err) - } + sendICMPError(t, &conn, icmpErr, udp) - conn.Send(testbench.UDP{DstPort: &cleanPort}) - conn.Send(testbench.UDP{}) + conn.Send(t, testbench.UDP{DstPort: &cleanPort}) + conn.Send(t, testbench.UDP{}) wg.Wait() }) } diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go new file mode 100644 index 000000000..526173969 --- /dev/null +++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_recv_mcast_bcast_test + +import ( + "context" + "flag" + "fmt" + "net" + "syscall" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestUDPRecvMcastBcast(t *testing.T) { + subnetBcastAddr := broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)) + + for _, v := range []struct { + bound, to net.IP + }{ + {bound: net.IPv4zero, to: subnetBcastAddr}, + {bound: net.IPv4zero, to: net.IPv4bcast}, + {bound: net.IPv4zero, to: net.IPv4allsys}, + + {bound: subnetBcastAddr, to: subnetBcastAddr}, + {bound: subnetBcastAddr, to: net.IPv4bcast}, + + {bound: net.IPv4bcast, to: net.IPv4bcast}, + {bound: net.IPv4allsys, to: net.IPv4allsys}, + } { + t.Run(fmt.Sprintf("bound=%s,to=%s", v.bound, v.to), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, v.bound) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(v.to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + got, want := dut.Recv(t, boundFD, int32(len(payload)+1), 0), payload + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + } +} + +func TestUDPDoesntRecvMcastBcastOnUnicastAddr(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4)) + dut.SetSockOptTimeval(t, boundFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &unix.Timeval{Sec: 1, Usec: 0}) + defer dut.Close(t, boundFD) + conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + defer conn.Close(t) + + for _, to := range []net.IP{ + broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)), + net.IPv4(255, 255, 255, 255), + net.IPv4(224, 0, 0, 1), + } { + t.Run(fmt.Sprint("to=%s", to), func(t *testing.T) { + payload := testbench.GenerateRandomPayload(t, 1<<10 /* 1 KiB */) + conn.SendIP( + t, + testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(to.To4()))}, + testbench.UDP{}, + &testbench.Payload{Bytes: payload}, + ) + ret, payload, errno := dut.RecvWithErrno(context.Background(), t, boundFD, 100, 0) + if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) + } + }) + } +} + +func broadcastAddr(ip net.IP, mask net.IPMask) net.IP { + ip4 := ip.To4() + for i := range ip4 { + ip4[i] |= ^mask[i] + } + return ip4 +} diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 224feef85..91b967400 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "gvisor.dev/gvisor/test/packetimpact/testbench" ) @@ -28,62 +29,75 @@ func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestUDPRecv(t *testing.T) { +type udpConn interface { + Send(*testing.T, testbench.UDP, ...testbench.Layer) + ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error) + Drain(*testing.T) + Close(*testing.T) +} + +func TestUDP(t *testing.T) { dut := testbench.NewDUT(t) defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) - conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) - if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want { - t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want) + for _, isIPv4 := range []bool{true, false} { + ipVersionName := "IPv6" + if isIPv4 { + ipVersionName = "IPv4" + } + t.Run(ipVersionName, func(t *testing.T) { + var addr string + if isIPv4 { + addr = testbench.RemoteIPv4 + } else { + addr = testbench.RemoteIPv6 } - }) - } -} + boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr)) + defer dut.Close(t, boundFD) -func TestUDPSend(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0")) - defer dut.Close(boundFD) - conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) - defer conn.Close() + var conn udpConn + var localAddr unix.Sockaddr + if isIPv4 { + v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v4Conn.LocalAddr(t) + conn = &v4Conn + } else { + v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort}) + localAddr = v6Conn.LocalAddr(t) + conn = &v6Conn + } + defer conn.Close(t) - testCases := []struct { - name string - payload []byte - }{ - {"emptypayload", nil}, - {"small payload", []byte("hello world")}, - {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, - // Even though UDP allows larger dgrams we don't test it here as - // they need to be fragmented and written out as individual - // frames. - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - conn.Drain() - if got, want := int(dut.SendTo(boundFD, tc.payload, 0, conn.LocalAddr())), len(tc.payload); got != want { - t.Fatalf("short write got: %d, want: %d", got, want) + testCases := []struct { + name string + payload []byte + }{ + {"emptypayload", nil}, + {"small payload", []byte("hello world")}, + {"1kPayload", testbench.GenerateRandomPayload(t, 1<<10)}, + // Even though UDP allows larger dgrams we don't test it here as + // they need to be fragmented and written out as individual + // frames. } - if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, 1*time.Second); err != nil { - t.Fatal(err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("Send", func(t *testing.T) { + conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload}) + got, want := dut.Recv(t, boundFD, int32(len(tc.payload)+1), 0), tc.payload + if diff := cmp.Diff(want, got); diff != "" { + t.Fatalf("received payload does not match sent payload, diff (-want, +got):\n%s", diff) + } + }) + t.Run("Recv", func(t *testing.T) { + conn.Drain(t) + if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want { + t.Fatalf("short write got: %d, want: %d", got, want) + } + if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil { + t.Fatal(err) + } + }) + }) } }) } diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc index 241f39896..e686041c9 100644 --- a/test/perf/linux/futex_benchmark.cc +++ b/test/perf/linux/futex_benchmark.cc @@ -84,9 +84,9 @@ BENCHMARK(BM_FutexWaitNop)->MinTime(5); // changes, such that it always times out. Timeout overhead can be estimated by // timer overruns for short timeouts. void BM_FutexWaitMonotonicTimeout(benchmark::State& state) { - const int timeout_ns = state.range(0); + const absl::Duration timeout = absl::Nanoseconds(state.range(0)); std::atomic<int32_t> v(0); - auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns)); + auto ts = absl::ToTimespec(timeout); for (auto _ : state) { TEST_PCHECK(FutexWaitMonotonicTimeout(&v, 0, &ts) == -1 && diff --git a/test/root/BUILD b/test/root/BUILD index a9e91ccd6..a9130b34f 100644 --- a/test/root/BUILD +++ b/test/root/BUILD @@ -41,7 +41,7 @@ go_test( "//runsc/container", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], @@ -51,8 +51,5 @@ vm_test( name = "root_vm_test", size = "large", shard_count = 1, - targets = [ - "//tools/installers:shim", - ":root_test", - ], + targets = [":root_test"], ) diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go index d0634b5c3..a26b83081 100644 --- a/test/root/cgroup_test.go +++ b/test/root/cgroup_test.go @@ -16,6 +16,7 @@ package root import ( "bufio" + "context" "fmt" "io/ioutil" "os" @@ -56,25 +57,24 @@ func verifyPid(pid int, path string) error { return fmt.Errorf("got: %v, want: %d", gots, pid) } -func TestMemCGroup(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() +func TestMemCgroup(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Start a new container and allocate the specified about of memory. allocMemSize := 128 << 20 allocMemLimit := 2 * allocMemSize - if err := d.Spawn(dockerutil.RunOpts{ - Image: "basic/python", - Memory: allocMemLimit / 1024, // Must be in Kb. - }, "python", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { + + if err := d.Spawn(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + Memory: allocMemLimit, // Must be in bytes. + }, "python3", "-c", fmt.Sprintf("import time; s = 'a' * %d; time.sleep(100)", allocMemSize)); err != nil { t.Fatalf("docker run failed: %v", err) } // Extract the ID to lookup the cgroup. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) - } + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Wait when the container will allocate memory. @@ -127,8 +127,9 @@ func TestMemCGroup(t *testing.T) { // TestCgroup sets cgroup options and checks that cgroup was properly configured. func TestCgroup(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // This is not a comprehensive list of attributes. // @@ -137,94 +138,133 @@ func TestCgroup(t *testing.T) { // are often run on a single core virtual machine, and there is only a single // CPU available in our current set, and every container's set. attrs := []struct { - arg string + field string + value int64 ctrl string file string want string skipIfNotFound bool }{ { - arg: "--cpu-shares=1000", - ctrl: "cpu", - file: "cpu.shares", - want: "1000", + field: "cpu-shares", + value: 1000, + ctrl: "cpu", + file: "cpu.shares", + want: "1000", }, { - arg: "--cpu-period=2000", - ctrl: "cpu", - file: "cpu.cfs_period_us", - want: "2000", + field: "cpu-period", + value: 2000, + ctrl: "cpu", + file: "cpu.cfs_period_us", + want: "2000", }, { - arg: "--cpu-quota=3000", - ctrl: "cpu", - file: "cpu.cfs_quota_us", - want: "3000", + field: "cpu-quota", + value: 3000, + ctrl: "cpu", + file: "cpu.cfs_quota_us", + want: "3000", }, { - arg: "--kernel-memory=100MB", - ctrl: "memory", - file: "memory.kmem.limit_in_bytes", - want: "104857600", + field: "kernel-memory", + value: 100 << 20, + ctrl: "memory", + file: "memory.kmem.limit_in_bytes", + want: "104857600", }, { - arg: "--memory=1GB", - ctrl: "memory", - file: "memory.limit_in_bytes", - want: "1073741824", + field: "memory", + value: 1 << 30, + ctrl: "memory", + file: "memory.limit_in_bytes", + want: "1073741824", }, { - arg: "--memory-reservation=500MB", - ctrl: "memory", - file: "memory.soft_limit_in_bytes", - want: "524288000", + field: "memory-reservation", + value: 500 << 20, + ctrl: "memory", + file: "memory.soft_limit_in_bytes", + want: "524288000", }, { - arg: "--memory-swap=2GB", + field: "memory-swap", + value: 2 << 30, ctrl: "memory", file: "memory.memsw.limit_in_bytes", want: "2147483648", skipIfNotFound: true, // swap may be disabled on the machine. }, { - arg: "--memory-swappiness=5", - ctrl: "memory", - file: "memory.swappiness", - want: "5", + field: "memory-swappiness", + value: 5, + ctrl: "memory", + file: "memory.swappiness", + want: "5", }, { - arg: "--blkio-weight=750", + field: "blkio-weight", + value: 750, ctrl: "blkio", file: "blkio.weight", want: "750", skipIfNotFound: true, // blkio groups may not be available. }, { - arg: "--pids-limit=1000", - ctrl: "pids", - file: "pids.max", - want: "1000", + field: "pids-limit", + value: 1000, + ctrl: "pids", + file: "pids.max", + want: "1000", }, } - args := make([]string, 0, len(attrs)) + // Make configs. + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ + Image: "basic/alpine", + }, "sleep", "10000") + + // Add Cgroup arguments to configs. for _, attr := range attrs { - args = append(args, attr.arg) + switch attr.field { + case "cpu-shares": + hostconf.Resources.CPUShares = attr.value + case "cpu-period": + hostconf.Resources.CPUPeriod = attr.value + case "cpu-quota": + hostconf.Resources.CPUQuota = attr.value + case "kernel-memory": + hostconf.Resources.KernelMemory = attr.value + case "memory": + hostconf.Resources.Memory = attr.value + case "memory-reservation": + hostconf.Resources.MemoryReservation = attr.value + case "memory-swap": + hostconf.Resources.MemorySwap = attr.value + case "memory-swappiness": + val := attr.value + hostconf.Resources.MemorySwappiness = &val + case "blkio-weight": + hostconf.Resources.BlkioWeight = uint16(attr.value) + case "pids-limit": + val := attr.value + hostconf.Resources.PidsLimit = &val + + } } - // Start the container. - if err := d.Spawn(dockerutil.RunOpts{ - Image: "basic/alpine", - Extra: args, // Cgroup arguments. - }, "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) + // Create container. + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - // Lookup the relevant cgroup ID. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + // Start container. + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Lookup the relevant cgroup ID. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check list of attributes defined above. @@ -239,7 +279,7 @@ func TestCgroup(t *testing.T) { t.Fatalf("failed to read %q: %v", path, err) } if got := strings.TrimSpace(string(out)); got != attr.want { - t.Errorf("arg: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.arg, attr.ctrl, attr.file, got, attr.want) + t.Errorf("field: %q, cgroup attribute %s/%s, got: %q, want: %q", attr.field, attr.ctrl, attr.file, got, attr.want) } } @@ -257,7 +297,7 @@ func TestCgroup(t *testing.T) { "pids", "systemd", } - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } @@ -269,29 +309,34 @@ func TestCgroup(t *testing.T) { } } -// TestCgroup sets cgroup options and checks that cgroup was properly configured. +// TestCgroupParent sets the "CgroupParent" option and checks that the child and parent's +// cgroups are created correctly relative to each other. func TestCgroupParent(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) // Construct a known cgroup name. parent := testutil.RandomID("runsc-") - if err := d.Spawn(dockerutil.RunOpts{ + conf, hostconf, _ := d.ConfigsFrom(dockerutil.RunOpts{ Image: "basic/alpine", - Extra: []string{fmt.Sprintf("--cgroup-parent=%s", parent)}, - }, "sleep", "10000"); err != nil { - t.Fatalf("docker run failed: %v", err) + }, "sleep", "10000") + hostconf.Resources.CgroupParent = parent + + if err := d.CreateFrom(ctx, conf, hostconf, nil); err != nil { + t.Fatalf("create failed with: %v", err) } - // Extract the ID to look up the cgroup. - gid, err := d.ID() - if err != nil { - t.Fatalf("Docker.ID() failed: %v", err) + if err := d.Start(ctx); err != nil { + t.Fatalf("start failed with: %v", err) } + + // Extract the ID to look up the cgroup. + gid := d.ID() t.Logf("cgroup ID: %s", gid) // Check that sandbox is inside cgroup. - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("SandboxPid: %v", err) } diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index a306132a4..58fcd6f08 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -16,6 +16,7 @@ package root import ( + "context" "fmt" "io/ioutil" "os/exec" @@ -30,16 +31,17 @@ import ( // TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned // up after the sandbox is destroyed. func TestChroot(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) } - pid, err := d.SandboxPid() + pid, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } @@ -75,14 +77,15 @@ func TestChroot(t *testing.T) { t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc") } - d.CleanUp() + d.CleanUp(ctx) } func TestChrootGofer(t *testing.T) { - d := dockerutil.MakeDocker(t) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) - if err := d.Spawn(dockerutil.RunOpts{ + if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/alpine", }, "sleep", "10000"); err != nil { t.Fatalf("docker run failed: %v", err) @@ -91,7 +94,7 @@ func TestChrootGofer(t *testing.T) { // It's tricky to find gofers. Get sandbox PID first, then find parent. From // parent get all immediate children, remove the sandbox, and everything else // are gofers. - sandPID, err := d.SandboxPid() + sandPID, err := d.SandboxPid(ctx) if err != nil { t.Fatalf("Docker.SandboxPid(): %v", err) } diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index 732fae821..df91fa0fe 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -20,13 +20,14 @@ import ( "fmt" "io" "io/ioutil" - "log" "net/http" "os" "os/exec" "path" - "path/filepath" + "regexp" + "strconv" "strings" + "sync" "testing" "time" @@ -75,6 +76,8 @@ func SimpleSpec(name, image string, cmd []string, extra map[string]interface{}) // Log files are not deleted after root tests are run. Log to random // paths to ensure logs are fresh. "log_path": fmt.Sprintf("%s.log", testutil.RandomID(name)), + "stdin": false, + "tty": false, } if len(cmd) > 0 { // Omit if empty. s["command"] = cmd @@ -95,25 +98,29 @@ var Httpd = SimpleSpec("httpd", "basic/httpd", nil, nil) // TestCrictlSanity refers to b/112433158. func TestCrictlSanity(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), Httpd) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - // Look for the httpd page. - if err = httpGet(crictl, podID, "index.html"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), Httpd) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + // Look for the httpd page. + if err = httpGet(crictl, podID, "index.html"); err != nil { + t.Fatalf("failed to get page: %v", err) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } @@ -147,146 +154,179 @@ var HttpdMountPaths = SimpleSpec("httpd", "basic/httpd", nil, map[string]interfa // TestMountPaths refers to b/117635704. func TestMountPaths(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - podID, contID, err := crictl.StartPodAndContainer("basic/httpd", Sandbox("default"), HttpdMountPaths) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - // Look for the directory available at /test. - if err = httpGet(crictl, podID, "test"); err != nil { - t.Fatalf("failed to get page: %v", err) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/httpd", Sandbox("default"), HttpdMountPaths) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + // Look for the directory available at /test. + if err = httpGet(crictl, podID, "test"); err != nil { + t.Fatalf("failed to get page: %v", err) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestMountPaths refers to b/118728671. func TestMountOverSymlinks(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) - } - defer cleanup() - - spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/resolv", Sandbox("default"), spec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") - if err != nil { - t.Fatalf("readlink failed: %v, out: %s", err, out) - } - if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { - t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) - } - - etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") - if err != nil { - t.Fatalf("cat failed: %v, out: %s", err, etc) - } - tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") - if err != nil { - t.Fatalf("cat failed: %v, out: %s", err, out) - } - if tmp != etc { - t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + + spec := SimpleSpec("busybox", "basic/resolv", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/resolv", Sandbox("default"), spec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Exec(contID, "readlink", "/etc/resolv.conf") + if err != nil { + t.Fatalf("readlink failed: %v, out: %s", err, out) + } + if want := "/tmp/resolv.conf"; !strings.Contains(string(out), want) { + t.Fatalf("/etc/resolv.conf is not pointing to %q: %q", want, string(out)) + } + + etc, err := crictl.Exec(contID, "cat", "/etc/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, etc) + } + tmp, err := crictl.Exec(contID, "cat", "/tmp/resolv.conf") + if err != nil { + t.Fatalf("cat failed: %v, out: %s", err, out) + } + if tmp != etc { + t.Fatalf("file content doesn't match:\n\t/etc/resolv.conf: %s\n\t/tmp/resolv.conf: %s", string(etc), string(tmp)) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) } } // TestHomeDir tests that the HOME environment variable is set for // Pod containers. func TestHomeDir(t *testing.T) { - // Setup containerd and crictl. - crictl, cleanup, err := setup(t) - if err != nil { - t.Fatalf("failed to setup crictl: %v", err) + for _, version := range allVersions { + t.Run(version, func(t *testing.T) { + // Setup containerd and crictl. + crictl, cleanup, err := setup(t, version) + if err != nil { + t.Fatalf("failed to setup crictl: %v", err) + } + defer cleanup() + + // Note that container ID returned here is a sub-container. All Pod + // containers are sub-containers. The root container of the sandbox is the + // pause container. + t.Run("sub-container", func(t *testing.T) { + contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("subcont-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Logs(contID) + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + + // Stop everything; note that the pod may have already stopped. + crictl.StopPodAndContainer(podID, contID) + }) + + // Tests that HOME is set for the exec process. + t.Run("exec", func(t *testing.T) { + contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil) + podID, contID, err := crictl.StartPodAndContainer(containerdRuntime, "basic/busybox", Sandbox("exec-sandbox"), contSpec) + if err != nil { + t.Fatalf("start failed: %v", err) + } + + out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") + if err != nil { + t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) + } + if got, want := strings.TrimSpace(string(out)), "/root"; got != want { + t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) + } + + // Stop everything. + if err := crictl.StopPodAndContainer(podID, contID); err != nil { + t.Fatalf("stop failed: %v", err) + } + }) + }) } - defer cleanup() - - // Note that container ID returned here is a sub-container. All Pod - // containers are sub-containers. The root container of the sandbox is the - // pause container. - t.Run("sub-container", func(t *testing.T) { - contSpec := SimpleSpec("subcontainer", "basic/busybox", []string{"sh", "-c", "echo $HOME"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("subcont-sandbox"), contSpec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Logs(contID) - if err != nil { - t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) - } - }) - - // Tests that HOME is set for the exec process. - t.Run("exec", func(t *testing.T) { - contSpec := SimpleSpec("exec", "basic/busybox", []string{"sleep", "1000"}, nil) - podID, contID, err := crictl.StartPodAndContainer("basic/busybox", Sandbox("exec-sandbox"), contSpec) - if err != nil { - t.Fatalf("start failed: %v", err) - } - - out, err := crictl.Exec(contID, "sh", "-c", "echo $HOME") - if err != nil { - t.Fatalf("failed retrieving container logs: %v, out: %s", err, out) - } - if got, want := strings.TrimSpace(string(out)), "/root"; got != want { - t.Fatalf("Home directory invalid. Got %q, Want : %q", got, want) - } - - // Stop everything. - if err := crictl.StopPodAndContainer(podID, contID); err != nil { - t.Fatalf("stop failed: %v", err) - } - }) } -// containerdConfigTemplate is a .toml config for containerd. It contains a -// formatting verb so the runtime field can be set via fmt.Sprintf. -const containerdConfigTemplate = ` +const containerdRuntime = "runsc" + +const v1Template = ` disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true [plugins.linux] - runtime = "%s" - runtime_root = "/tmp/test-containerd/runsc" - shim = "/usr/local/bin/gvisor-containerd-shim" + shim = "%s" shim_debug = true - -[plugins.cri.containerd.runtimes.runsc] +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] runtime_type = "io.containerd.runtime.v1.linux" runtime_engine = "%s" + runtime_root = "%s/root/runsc" ` +const v2Template = ` +disabled_plugins = ["restart"] +[plugins.cri] + disable_tcp_service = true +[plugins.linux] + shim_debug = true +[plugins.cri.containerd.runtimes.` + containerdRuntime + `] + runtime_type = "io.containerd.` + containerdRuntime + `.v1" +[plugins.cri.containerd.runtimes.` + containerdRuntime + `.options] + TypeUrl = "io.containerd.` + containerdRuntime + `.v1.options" +` + +const ( + // v1 is the containerd API v1. + v1 string = "v1" + + // v1 is the containerd API v21. + v2 string = "v2" +) + +// allVersions is the set of known versions. +var allVersions = []string{v1, v2} + // setup sets up before a test. Specifically it: // * Creates directories and a socket for containerd to utilize. // * Runs containerd and waits for it to reach a "ready" state for testing. // * Returns a cleanup function that should be called at the end of the test. -func setup(t *testing.T) (*criutil.Crictl, func(), error) { +func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) { // Create temporary containerd root and state directories, and a socket // via which crictl and containerd communicate. containerdRoot, err := ioutil.TempDir(testutil.TmpDir(), "containerd-root") @@ -295,13 +335,43 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { } cu := cleanup.Make(func() { os.RemoveAll(containerdRoot) }) defer cu.Clean() + t.Logf("Using containerd root: %s", containerdRoot) containerdState, err := ioutil.TempDir(testutil.TmpDir(), "containerd-state") if err != nil { t.Fatalf("failed to create containerd state: %v", err) } cu.Add(func() { os.RemoveAll(containerdState) }) - sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock") + t.Logf("Using containerd state: %s", containerdState) + + sockDir, err := ioutil.TempDir(testutil.TmpDir(), "containerd-sock") + if err != nil { + t.Fatalf("failed to create containerd socket directory: %v", err) + } + cu.Add(func() { os.RemoveAll(sockDir) }) + sockAddr := path.Join(sockDir, "test.sock") + t.Logf("Using containerd socket: %s", sockAddr) + + // Extract the containerd version. + versionCmd := exec.Command(getContainerd(), "-v") + out, err := versionCmd.CombinedOutput() + if err != nil { + t.Fatalf("error extracting containerd version: %v (%s)", err, string(out)) + } + r := regexp.MustCompile(" v([0-9]+)\\.([0-9]+)\\.([0-9+])") + vs := r.FindStringSubmatch(string(out)) + if len(vs) != 4 { + t.Fatalf("error unexpected version string: %s", string(out)) + } + major, err := strconv.ParseUint(vs[1], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd major version: %v (%s)", err, string(out)) + } + minor, err := strconv.ParseUint(vs[2], 10, 64) + if err != nil { + t.Fatalf("error parsing containerd minor version: %v (%s)", err, string(out)) + } + t.Logf("Using containerd version: %d.%d", major, minor) // We rewrite a configuration. This is based on the current docker // configuration for the runtime under test. @@ -309,28 +379,97 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { if err != nil { t.Fatalf("error discovering runtime path: %v", err) } - config, configCleanup, err := testutil.WriteTmpFile("containerd-config", fmt.Sprintf(containerdConfigTemplate, runtime, runtime)) + t.Logf("Using runtime: %v", runtime) + + // Construct a PATH that includes the runtime directory. This is + // because the shims will be installed there, and containerd may infer + // the binary name and search the PATH. + runtimeDir := path.Dir(runtime) + modifiedPath := os.Getenv("PATH") + if modifiedPath != "" { + modifiedPath = ":" + modifiedPath // We prepend below. + } + modifiedPath = path.Dir(getContainerd()) + modifiedPath + modifiedPath = runtimeDir + ":" + modifiedPath + t.Logf("Using PATH: %v", modifiedPath) + + var ( + config string + runpArgs []string + ) + switch version { + case v1: + // This is only supported less than 1.3. + if major > 1 || (major == 1 && minor >= 3) { + t.Skipf("skipping unsupported containerd (want less than 1.3, got %d.%d)", major, minor) + } + + // We provide the shim, followed by the runtime, and then a + // temporary root directory. + config = fmt.Sprintf(v1Template, criutil.ResolvePath("gvisor-containerd-shim"), runtime, containerdRoot) + case v2: + // This is only supported past 1.2. + if major < 1 || (major == 1 && minor <= 1) { + t.Skipf("skipping incompatible containerd (want at least 1.2, got %d.%d)", major, minor) + } + + // The runtime is provided via parameter. Note that the v2 shim + // binary name is always containerd-shim-* so we don't actually + // care about the docker runtime name. + config = v2Template + default: + t.Fatalf("unknown version: %d", version) + } + t.Logf("Using config: %s", config) + + // Generate the configuration for the test. + configFile, configCleanup, err := testutil.WriteTmpFile("containerd-config", config) if err != nil { t.Fatalf("failed to write containerd config") } cu.Add(configCleanup) // Start containerd. - cmd := exec.Command(getContainerd(), - "--config", config, + args := []string{ + getContainerd(), + "--config", configFile, "--log-level", "debug", "--root", containerdRoot, "--state", containerdState, - "--address", sockAddr) + "--address", sockAddr, + } + t.Logf("Using args: %s", strings.Join(args, " ")) + cmd := exec.Command(args[0], args[1:]...) + cmd.Env = append(os.Environ(), "PATH="+modifiedPath) + + // Include output in logs. + stderrPipe, err := cmd.StderrPipe() + if err != nil { + t.Fatalf("failed to create stderr pipe: %v", err) + } + cu.Add(func() { stderrPipe.Close() }) + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + t.Fatalf("failed to create stdout pipe: %v", err) + } + cu.Add(func() { stdoutPipe.Close() }) + var ( + wg sync.WaitGroup + stderr bytes.Buffer + stdout bytes.Buffer + ) startupR, startupW := io.Pipe() - defer startupR.Close() - defer startupW.Close() - stderr := &bytes.Buffer{} - stdout := &bytes.Buffer{} - cmd.Stderr = io.MultiWriter(startupW, stderr) - cmd.Stdout = io.MultiWriter(startupW, stdout) + wg.Add(2) + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stderr), stderrPipe) + }() + go func() { + defer wg.Done() + io.Copy(io.MultiWriter(startupW, &stdout), stdoutPipe) + }() cu.Add(func() { - // Log output in case of failure. + wg.Wait() t.Logf("containerd stdout: %s", stdout.String()) t.Logf("containerd stderr: %s", stderr.String()) }) @@ -345,13 +484,17 @@ func setup(t *testing.T) (*criutil.Crictl, func(), error) { t.Fatalf("failed to start containerd: %v", err) } + // Discard all subsequent data. + go io.Copy(ioutil.Discard, startupR) + + // Create the crictl interface. + cc := criutil.NewCrictl(t, sockAddr, runpArgs) + cu.Add(cc.CleanUp) + // Kill must be the last cleanup (as it will be executed first). - cc := criutil.NewCrictl(t, sockAddr) cu.Add(func() { - cc.CleanUp() // Remove tmp files, etc. - if err := testutil.KillCommand(cmd); err != nil { - log.Printf("error killing containerd: %v", err) - } + // Best effort: ignore errors. + testutil.KillCommand(cmd) }) return cc, cu.Release(), nil diff --git a/test/runner/BUILD b/test/runner/BUILD index 6833c9986..63c7ec83a 100644 --- a/test/runner/BUILD +++ b/test/runner/BUILD @@ -16,7 +16,8 @@ go_binary( "//runsc/specutils", "//test/runner/gtest", "//test/uds", - "@com_github_opencontainers_runtime-spec//specs-go:go_default_library", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@com_github_syndtr_gocapability//capability:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 5a83f8060..c92392b35 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -61,7 +61,8 @@ def _syscall_test( file_access = "exclusive", overlay = False, add_uds_tree = False, - vfs2 = False): + vfs2 = False, + fuse = False): # Prepend "runsc" to non-native platform names. full_platform = platform if platform == "native" else "runsc_" + platform @@ -73,6 +74,8 @@ def _syscall_test( name += "_overlay" if vfs2: name += "_vfs2" + if fuse: + name += "_fuse" if network != "none": name += "_" + network + "net" @@ -107,6 +110,7 @@ def _syscall_test( "--overlay=" + str(overlay), "--add-uds-tree=" + str(add_uds_tree), "--vfs2=" + str(vfs2), + "--fuse=" + str(fuse), ] # Call the rule above. @@ -129,6 +133,7 @@ def syscall_test( add_uds_tree = False, add_hostinet = False, vfs2 = False, + fuse = False, tags = None): """syscall_test is a macro that will create targets for all platforms. @@ -152,7 +157,7 @@ def syscall_test( platform = "native", use_tmpfs = False, add_uds_tree = add_uds_tree, - tags = tags, + tags = list(tags), ) for (platform, platform_tags) in platforms.items(): @@ -188,6 +193,19 @@ def syscall_test( vfs2 = True, ) + if vfs2 and fuse: + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = default_platform, + use_tmpfs = use_tmpfs, + add_uds_tree = add_uds_tree, + tags = platforms[default_platform] + vfs2_tags, + vfs2 = True, + fuse = True, + ) + # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. if add_overlay: _syscall_test( @@ -195,7 +213,7 @@ def syscall_test( shard_count = shard_count, size = size, platform = default_platform, - use_tmpfs = False, # overlay is adding a writable tmpfs on top of root. + use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, tags = platforms[default_platform] + tags, overlay = True, diff --git a/test/runner/runner.go b/test/runner/runner.go index 5456e46a6..bc4b39cbb 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -30,6 +30,7 @@ import ( "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "github.com/syndtr/gocapability/capability" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/testutil" @@ -47,6 +48,7 @@ var ( fileAccess = flag.String("file-access", "exclusive", "mounts root in exclusive or shared mode") overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable tmpfs overlay") vfs2 = flag.Bool("vfs2", false, "enable VFS2") + fuse = flag.Bool("fuse", false, "enable FUSE") parallel = flag.Bool("parallel", false, "run tests in parallel") runscPath = flag.String("runsc", "", "path to runsc binary") @@ -104,6 +106,13 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + + if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { + cmd.SysProcAttr = &syscall.SysProcAttr{ + Cloneflags: syscall.CLONE_NEWNET, + } + } + if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) @@ -149,6 +158,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { } if *vfs2 { args = append(args, "-vfs2") + if *fuse { + args = append(args, "-fuse") + } } if *debug { args = append(args, "-debug", "-log-packets=true") @@ -358,6 +370,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { vfsVar := "GVISOR_VFS" if *vfs2 { env = append(env, vfsVar+"=VFS2") + fuseVar := "FUSE_ENABLED" + if *fuse { + env = append(env, fuseVar+"=TRUE") + } else { + env = append(env, fuseVar+"=FALSE") + } } else { env = append(env, vfsVar+"=VFS1") } diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 022de5ff7..1728161ce 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -6,28 +6,33 @@ runtime_test( name = "go1.12", exclude_file = "exclude_go1.12.csv", lang = "go", + shard_count = 5, ) runtime_test( name = "java11", exclude_file = "exclude_java11.csv", lang = "java", + shard_count = 20, ) runtime_test( name = "nodejs12.4.0", exclude_file = "exclude_nodejs12.4.0.csv", lang = "nodejs", + shard_count = 10, ) runtime_test( name = "php7.3.6", exclude_file = "exclude_php7.3.6.csv", lang = "php", + shard_count = 5, ) runtime_test( name = "python3.7.3", exclude_file = "exclude_python3.7.3.csv", lang = "python", + shard_count = 5, ) diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl index dc3667f05..5779b9591 100644 --- a/test/runtimes/defs.bzl +++ b/test/runtimes/defs.bzl @@ -20,7 +20,7 @@ def _runtime_test_impl(ctx): runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) runner_content = "\n".join([ "#!/bin/bash", - "%s %s\n" % (ctx.files._runner[0].short_path, " ".join(args)), + "%s %s $@\n" % (ctx.files._runner[0].short_path, " ".join(args)), ]) ctx.actions.write(runner, runner_content, is_executable = True) @@ -65,6 +65,7 @@ def runtime_test(name, **kwargs): "local", "manual", ], + size = "enormous", **kwargs ) diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude_go1.12.csv index 8c8ae0c5d..81e02cf64 100644 --- a/test/runtimes/exclude_go1.12.csv +++ b/test/runtimes/exclude_go1.12.csv @@ -2,15 +2,12 @@ test name,bug id,comment cgo_errors,,FLAKY cgo_test,,FLAKY go_test:cmd/go,,FLAKY -go_test:cmd/vendor/golang.org/x/sys/unix,b/118783622,/dev devices missing -go_test:net,b/118784196,socket: invalid argument. Works as intended: see bug. +go_test:net,b/162473575,setsockopt: protocol not available. go_test:os,b/118780122,we have a pollable filesystem but that's a surprise -go_test:os/signal,b/118780860,/dev/pts not properly supported -go_test:runtime,b/118782341,sigtrap not reported or caught or something -go_test:syscall,b/118781998,bad bytes -- bad mem addr -race,b/118782931,thread sanitizer. Works as intended: b/62219744. +go_test:os/signal,b/118780860,/dev/pts not properly supported. Also being tracked in b/29356795. +go_test:runtime,b/118782341,sigtrap not reported or caught or something. Also being tracked in b/33003106. +go_test:syscall,b/118781998,bad bytes -- bad mem addr; FcntlFlock(F_GETLK) not supported. runtime:cpu124,b/118778254,segmentation fault test:0_1,,FLAKY -testasan,, testcarchive,b/118782924,no sigpipe testshared,,FLAKY diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv index 4ab4e2927..4eb0a4807 100644 --- a/test/runtimes/exclude_nodejs12.4.0.csv +++ b/test/runtimes/exclude_nodejs12.4.0.csv @@ -9,6 +9,8 @@ fixtures/test-fs-stat-sync-overflow.js,, internet/test-dgram-broadcast-multi-process.js,, internet/test-dgram-multicast-multi-process.js,, internet/test-dgram-multicast-set-interface-lo.js,, +internet/test-doctool-versions.js,, +internet/test-uv-threadpool-schedule.js,, parallel/test-cluster-dgram-reuse.js,b/64024294, parallel/test-dgram-bind-fd.js,b/132447356, parallel/test-dgram-create-socket-handle-fd.js,b/132447238, @@ -45,3 +47,5 @@ pseudo-tty/test-tty-window-size.js,, pseudo-tty/test-tty-wrap.js,, pummel/test-net-pingpong.js,, pummel/test-vm-memleak.js,, +sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout +tick-processor/test-tick-processor-builtin.js,, diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude_php7.3.6.csv index 456bf7487..0bef786c0 100644 --- a/test/runtimes/exclude_php7.3.6.csv +++ b/test/runtimes/exclude_php7.3.6.csv @@ -8,6 +8,10 @@ ext/mbstring/tests/bug77165.phpt,, ext/mbstring/tests/bug77454.phpt,, ext/mbstring/tests/mb_convert_encoding_leak.phpt,, ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/session/tests/session_set_save_handler_class_018.phpt,, +ext/session/tests/session_set_save_handler_iface_003.phpt,, +ext/session/tests/session_set_save_handler_variation4.phpt,, +ext/session/tests/session_set_save_handler_variation5.phpt,, ext/standard/tests/file/filetype_variation.phpt,, ext/standard/tests/file/fopen_variation19.phpt,, ext/standard/tests/file/php_fd_wrapper_01.phpt,, @@ -21,9 +25,13 @@ ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, ext/standard/tests/network/bug20134.phpt,, +ext/standard/tests/streams/stream_socket_sendto.phpt,, +ext/standard/tests/strings/007.phpt,, +sapi/cli/tests/upload_2G.phpt,, tests/output/stream_isatty_err.phpt,b/68720279, tests/output/stream_isatty_in-err.phpt,b/68720282, tests/output/stream_isatty_in-out-err.phpt,, tests/output/stream_isatty_in-out.phpt,b/68720299, tests/output/stream_isatty_out-err.phpt,b/68720311, tests/output/stream_isatty_out.phpt,b/68720325, +Zend/tests/concat_003.phpt,, diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/go.go index 3e2d5d8db..d0ae844e6 100644 --- a/test/runtimes/proctor/go.go +++ b/test/runtimes/proctor/go.go @@ -74,17 +74,26 @@ func (goRunner) ListTests() ([]string, error) { return append(toolSlice, diskFiltered...), nil } -// TestCmd implements TestRunner.TestCmd. -func (goRunner) TestCmd(test string) *exec.Cmd { - // Check if test exists on disk by searching for file of the same name. - // This will determine whether or not it is a Go test on disk. - if strings.HasSuffix(test, ".go") { - // Test has suffix ".go" which indicates a disk test, run it as such. - cmd := exec.Command("go", "run", "run.go", "-v", "--", test) +// TestCmds implements TestRunner.TestCmds. +func (goRunner) TestCmds(tests []string) []*exec.Cmd { + var toolTests, onDiskTests []string + for _, test := range tests { + if strings.HasSuffix(test, ".go") { + onDiskTests = append(onDiskTests, test) + } else { + toolTests = append(toolTests, "^"+test+"$") + } + } + + var cmds []*exec.Cmd + if len(toolTests) > 0 { + cmds = append(cmds, exec.Command("go", "tool", "dist", "test", "-v", "-no-rebuild", "-run", strings.Join(toolTests, "\\|"))) + } + if len(onDiskTests) > 0 { + cmd := exec.Command("go", append([]string{"run", "run.go", "-v", "--"}, onDiskTests...)...) cmd.Dir = goTestDir - return cmd + cmds = append(cmds, cmd) } - // No ".go" suffix, run as a tool test. - return exec.Command("go", "tool", "dist", "test", "-run", test) + return cmds } diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/java.go index 8b362029d..737fbe23e 100644 --- a/test/runtimes/proctor/java.go +++ b/test/runtimes/proctor/java.go @@ -60,12 +60,14 @@ func (javaRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (javaRunner) TestCmd(test string) *exec.Cmd { - args := []string{ - "-noreport", - "-dir:" + javaTestDir, - test, - } - return exec.Command("jtreg", args...) +// TestCmds implements TestRunner.TestCmds. +func (javaRunner) TestCmds(tests []string) []*exec.Cmd { + args := append( + []string{ + "-noreport", + "-dir:" + javaTestDir, + }, + tests..., + ) + return []*exec.Cmd{exec.Command("jtreg", args...)} } diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/nodejs.go index bd57db444..23d6edc72 100644 --- a/test/runtimes/proctor/nodejs.go +++ b/test/runtimes/proctor/nodejs.go @@ -39,8 +39,8 @@ func (nodejsRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (nodejsRunner) TestCmd(test string) *exec.Cmd { - args := []string{filepath.Join("tools", "test.py"), test} - return exec.Command("/usr/bin/python", args...) +// TestCmds implements TestRunner.TestCmds. +func (nodejsRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{filepath.Join("tools", "test.py")}, tests...) + return []*exec.Cmd{exec.Command("/usr/bin/python", args...)} } diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/php.go index 9115040e1..6a83d64e3 100644 --- a/test/runtimes/proctor/php.go +++ b/test/runtimes/proctor/php.go @@ -17,6 +17,7 @@ package main import ( "os/exec" "regexp" + "strings" ) var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`) @@ -35,8 +36,8 @@ func (phpRunner) ListTests() ([]string, error) { return testSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (phpRunner) TestCmd(test string) *exec.Cmd { - args := []string{"test", "TESTS=" + test} - return exec.Command("make", args...) +// TestCmds implements TestRunner.TestCmds. +func (phpRunner) TestCmds(tests []string) []*exec.Cmd { + args := []string{"test", "TESTS=" + strings.Join(tests, " ")} + return []*exec.Cmd{exec.Command("make", args...)} } diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/proctor.go index b54abe434..9e0642424 100644 --- a/test/runtimes/proctor/proctor.go +++ b/test/runtimes/proctor/proctor.go @@ -25,6 +25,7 @@ import ( "os/signal" "path/filepath" "regexp" + "strings" "syscall" ) @@ -34,15 +35,18 @@ type TestRunner interface { // ListTests returns a string slice of tests available to run. ListTests() ([]string, error) - // TestCmd returns an *exec.Cmd that will run the given test. - TestCmd(test string) *exec.Cmd + // TestCmds returns a slice of *exec.Cmd that will run the given tests. + // There is no correlation between the number of exec.Cmds returned and the + // number of tests. It could return one command to run all tests or a few + // commands that collectively run all. + TestCmds(tests []string) []*exec.Cmd } var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - testName = flag.String("test", "", "run a single test from the list of available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testNames = flag.String("tests", "", "run a subset of the available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) func main() { @@ -75,18 +79,20 @@ func main() { } var tests []string - if *testName == "" { + if *testNames == "" { // Run every test. tests, err = tr.ListTests() if err != nil { log.Fatalf("failed to get all tests: %v", err) } } else { - // Run a single test. - tests = []string{*testName} + // Run subset of test. + tests = strings.Split(*testNames, ",") } - for _, test := range tests { - cmd := tr.TestCmd(test) + + // Run tests. + cmds := tr.TestCmds(tests) + for _, cmd := range cmds { cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr if err := cmd.Run(); err != nil { log.Fatalf("FAIL: %v", err) diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/python.go index b9e0fbe6f..7c598801b 100644 --- a/test/runtimes/proctor/python.go +++ b/test/runtimes/proctor/python.go @@ -42,8 +42,8 @@ func (pythonRunner) ListTests() ([]string, error) { return toolSlice, nil } -// TestCmd implements TestRunner.TestCmd. -func (pythonRunner) TestCmd(test string) *exec.Cmd { - args := []string{"-m", "test", test} - return exec.Command("./python", args...) +// TestCmds implements TestRunner.TestCmds. +func (pythonRunner) TestCmds(tests []string) []*exec.Cmd { + args := append([]string{"-m", "test"}, tests...) + return []*exec.Cmd{exec.Command("./python", args...)} } diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD index 3972244b9..dc0d5d5b4 100644 --- a/test/runtimes/runner/BUILD +++ b/test/runtimes/runner/BUILD @@ -8,6 +8,7 @@ go_binary( srcs = ["main.go"], visibility = ["//test/runtimes:__pkg__"], deps = [ + "//pkg/log", "//pkg/test/dockerutil", "//pkg/test/testutil", ], diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/exclude_test.go index c08755894..67c2170c8 100644 --- a/test/runtimes/runner/exclude_test.go +++ b/test/runtimes/runner/exclude_test.go @@ -26,7 +26,7 @@ func TestMain(m *testing.M) { } // Test that the exclude file parses without error. -func TestBlacklists(t *testing.T) { +func TestExcludelist(t *testing.T) { ex, err := getExcludes() if err != nil { t.Fatalf("error parsing exclude file: %v", err) diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index 54d1169ef..e230912c9 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -16,6 +16,7 @@ package main import ( + "context" "encoding/csv" "flag" "fmt" @@ -26,6 +27,7 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -34,10 +36,11 @@ var ( lang = flag.String("lang", "", "language runtime to test") image = flag.String("image", "", "docker image with runtime tests") excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") + batchSize = flag.Int("batch", 50, "number of test cases run in one command") ) // Wait time for each test to run. -const timeout = 5 * time.Minute +const timeout = 45 * time.Minute func main() { flag.Parse() @@ -60,13 +63,19 @@ func runTests() int { } // Construct the shared docker instance. - d := dockerutil.MakeDocker(testutil.DefaultLogger(*lang)) - defer d.CleanUp() + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang)) + defer d.CleanUp(ctx) + + if err := testutil.TouchShardStatusFile(); err != nil { + fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) + return 1 + } // Get a slice of tests to run. This will also start a single Docker // container that will be used to run each test. The final test will // stop the Docker container. - tests, err := getTests(d, excludes) + tests, err := getTests(ctx, d, excludes) if err != nil { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) return 1 @@ -77,18 +86,18 @@ func runTests() int { } // getTests executes all tests as table tests. -func getTests(d *dockerutil.Docker, excludes map[string]struct{}) ([]testing.InternalTest, error) { +func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) { // Start the container. opts := dockerutil.RunOpts{ Image: fmt.Sprintf("runtimes/%s", *image), } d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") - if err := d.Spawn(opts, "/proctor/proctor", "--pause"); err != nil { + if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { return nil, fmt.Errorf("docker run failed: %v", err) } // Get a list of all tests in the image. - list, err := d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") + list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") if err != nil { return nil, fmt.Errorf("docker exec failed: %v", err) } @@ -103,17 +112,23 @@ func getTests(d *dockerutil.Docker, excludes map[string]struct{}) ([]testing.Int } var itests []testing.InternalTest - for _, tci := range indices { - // Capture tc in this scope. - tc := tests[tci] + for i := 0; i < len(indices); i += *batchSize { + var tcs []string + end := i + *batchSize + if end > len(indices) { + end = len(indices) + } + for _, tc := range indices[i:end] { + // Add test if not excluded. + if _, ok := excludes[tests[tc]]; ok { + log.Infof("Skipping test case %s\n", tests[tc]) + continue + } + tcs = append(tcs, tests[tc]) + } itests = append(itests, testing.InternalTest{ - Name: tc, + Name: strings.Join(tcs, ", "), F: func(t *testing.T) { - // Is the test excluded? - if _, ok := excludes[tc]; ok { - t.Skipf("SKIP: excluded test %q", tc) - } - var ( now = time.Now() done = make(chan struct{}) @@ -122,20 +137,20 @@ func getTests(d *dockerutil.Docker, excludes map[string]struct{}) ([]testing.Int ) go func() { - fmt.Printf("RUNNING %s...\n", tc) - output, err = d.Exec(dockerutil.RunOpts{}, "/proctor/proctor", "--runtime", *lang, "--test", tc) + fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ",")) close(done) }() select { case <-done: if err == nil { - fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now)) + fmt.Printf("PASS: (%v)\n\n", time.Since(now)) return } - t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output) + t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) case <-time.After(timeout): - t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output) + t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) } }, }) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 1638a11c7..c19b30b4a 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -57,6 +57,7 @@ syscall_test( size = "large", add_overlay = True, test = "//test/syscalls/linux:bind_test", + vfs2 = "True", ) syscall_test( @@ -145,7 +146,9 @@ syscall_test( ) syscall_test( + fuse = "True", test = "//test/syscalls/linux:dev_test", + vfs2 = "True", ) syscall_test( @@ -180,6 +183,7 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_binary_test", + vfs2 = "True", ) syscall_test( @@ -190,11 +194,13 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:fadvise64_test", + vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fallocate_test", + vfs2 = "True", ) syscall_test( @@ -238,6 +244,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:fsync_test", + vfs2 = "True", ) syscall_test( @@ -260,6 +267,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:getdents_test", + vfs2 = "True", ) syscall_test( @@ -276,12 +284,14 @@ syscall_test( size = "medium", add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. test = "//test/syscalls/linux:inotify_test", + vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:ioctl_test", + vfs2 = "True", ) syscall_test( @@ -305,6 +315,7 @@ syscall_test( add_overlay = True, test = "//test/syscalls/linux:link_test", use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) + vfs2 = "True", ) syscall_test( @@ -342,12 +353,14 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:mknod_test", + vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:mmap_test", + vfs2 = "True", ) syscall_test( @@ -387,6 +400,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_test", + vfs2 = "True", ) syscall_test( @@ -461,6 +475,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv2_test", + vfs2 = "True", ) syscall_test( @@ -471,6 +486,7 @@ syscall_test( syscall_test( size = "medium", test = "//test/syscalls/linux:proc_test", + vfs2 = "True", ) syscall_test( @@ -490,6 +506,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:proc_pid_uid_gid_map_test", + vfs2 = "True", ) syscall_test( @@ -518,6 +535,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", + vfs2 = "True", ) syscall_test( @@ -537,7 +555,7 @@ syscall_test( ) syscall_test( - test = "//test/syscalls/linux:raw_socket_ipv4_test", + test = "//test/syscalls/linux:raw_socket_test", vfs2 = "True", ) @@ -550,6 +568,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:readahead_test", + vfs2 = "True", ) syscall_test( @@ -570,6 +589,7 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:rename_test", + vfs2 = "True", ) syscall_test( @@ -621,11 +641,13 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_socket_test", + vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_test", + vfs2 = "True", ) syscall_test( @@ -699,6 +721,7 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", + vfs2 = "True", ) syscall_test( @@ -706,12 +729,14 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", + vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", + vfs2 = "True", ) syscall_test( @@ -720,6 +745,7 @@ syscall_test( # Takes too long for TSAN. Creates a lot of TCP sockets. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", + vfs2 = "True", ) syscall_test( @@ -795,6 +821,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:socket_blocking_local_test", + vfs2 = "True", ) syscall_test( @@ -804,6 +831,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", + vfs2 = "True", ) syscall_test( @@ -814,6 +842,7 @@ syscall_test( syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_local_test", + vfs2 = "True", ) syscall_test( @@ -825,11 +854,13 @@ syscall_test( syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_local_test", + vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_nonblock_local_test", + vfs2 = "True", ) syscall_test( @@ -837,6 +868,7 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_dgram_local_test", + vfs2 = "True", ) syscall_test( @@ -858,11 +890,13 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", + vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_stream_test", + vfs2 = "True", ) syscall_test( @@ -880,6 +914,7 @@ syscall_test( syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test", + vfs2 = "True", ) syscall_test( @@ -893,6 +928,7 @@ syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", + vfs2 = "True", ) syscall_test( @@ -916,6 +952,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:sticky_test", + vfs2 = "True", ) syscall_test( @@ -933,6 +970,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_file_range_test", + vfs2 = "True", ) syscall_test( @@ -989,6 +1027,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:tuntap_test", + vfs2 = "True", ) syscall_test( @@ -1072,6 +1111,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:proc_net_unix_test", + vfs2 = "True", ) syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 078e4a284..66a31cd28 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -748,9 +748,14 @@ cc_binary( linkstatic = 1, deps = [ ":file_base", + ":socket_test_util", "//test/util:cleanup", + "//test/util:eventfd_util", "//test/util:file_descriptor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", gtest, + "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -805,6 +810,7 @@ cc_binary( "//test/util:save_util", "//test/util:temp_path", "//test/util:test_util", + "//test/util:thread_util", "//test/util:timer_util", ], ) @@ -937,6 +943,7 @@ cc_binary( "//test/util:eventfd_util", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", gtest, "//test/util:posix_error", @@ -974,6 +981,7 @@ cc_binary( "//test/util:thread_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", ], ) @@ -1323,6 +1331,7 @@ cc_binary( name = "packet_socket_raw_test", testonly = 1, srcs = ["packet_socket_raw.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -1799,9 +1808,10 @@ cc_binary( ) cc_binary( - name = "raw_socket_ipv4_test", + name = "raw_socket_test", testonly = 1, - srcs = ["raw_socket_ipv4.cc"], + srcs = ["raw_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -3400,6 +3410,7 @@ cc_binary( name = "tcp_socket_test", testonly = 1, srcs = ["tcp_socket.cc"], + defines = select_system(), linkstatic = 1, deps = [ ":socket_test_util", @@ -3546,11 +3557,15 @@ cc_library( hdrs = ["udp_socket_test_cases.h"], defines = select_system(), deps = [ + ":ip_socket_test_util", ":socket_test_util", ":unix_domain_socket_test_util", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", gtest, + "//test/util:file_descriptor", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 4dd302eed..1d0d584cd 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -153,6 +153,27 @@ TEST(DevTest, TTYExists) { EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); } +TEST(DevTest, OpenDevFuse) { + // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new + // device registration is complete. + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor() || !IsFUSEEnabled()); + + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); +} + +TEST(DevTest, ReadDevFuseWithoutMount) { + // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new + // device registration is complete. + SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_RDONLY)); + + std::vector<char> buf(1); + EXPECT_THAT(ReadFd(fd.get(), buf.data(), sizeof(buf)), + SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/epoll.cc b/test/syscalls/linux/epoll.cc index f57d38dc7..2101e5c9f 100644 --- a/test/syscalls/linux/epoll.cc +++ b/test/syscalls/linux/epoll.cc @@ -422,6 +422,28 @@ TEST(EpollTest, CloseFile) { SyscallSucceedsWithValue(0)); } +TEST(EpollTest, PipeReaderHupAfterWriterClosed) { + auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD()); + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + FileDescriptor rfd(pipefds[0]); + FileDescriptor wfd(pipefds[1]); + + ASSERT_NO_ERRNO(RegisterEpollFD(epollfd.get(), rfd.get(), 0, kMagicConstant)); + struct epoll_event result[kFDsPerEpoll]; + // Initially, rfd should not generate any events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(0)); + // Close the write end of the pipe. + wfd.reset(); + // rfd should now generate EPOLLHUP, which EPOLL_CTL_ADD unconditionally adds + // to the set of events of interest. + ASSERT_THAT(epoll_wait(epollfd.get(), result, kFDsPerEpoll, 0), + SyscallSucceedsWithValue(1)); + EXPECT_EQ(result[0].events, EPOLLHUP); + EXPECT_EQ(result[0].data.u64, kMagicConstant); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/eventfd.cc b/test/syscalls/linux/eventfd.cc index 548b05a64..dc794415e 100644 --- a/test/syscalls/linux/eventfd.cc +++ b/test/syscalls/linux/eventfd.cc @@ -100,20 +100,21 @@ TEST(EventfdTest, SmallRead) { ASSERT_THAT(read(efd.get(), &l, 4), SyscallFailsWithErrno(EINVAL)); } -TEST(EventfdTest, PreadIllegalSeek) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); - - uint64_t l = 0; - ASSERT_THAT(pread(efd.get(), &l, 4, 0), SyscallFailsWithErrno(ESPIPE)); +TEST(EventfdTest, IllegalSeek) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + EXPECT_THAT(lseek(efd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); } -TEST(EventfdTest, PwriteIllegalSeek) { - FileDescriptor efd = - ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); +TEST(EventfdTest, IllegalPread) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + int l; + EXPECT_THAT(pread(efd.get(), &l, sizeof(l), 0), + SyscallFailsWithErrno(ESPIPE)); +} - uint64_t l = 0; - ASSERT_THAT(pwrite(efd.get(), &l, 4, 0), SyscallFailsWithErrno(ESPIPE)); +TEST(EventfdTest, IllegalPwrite) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + EXPECT_THAT(pwrite(efd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); } TEST(EventfdTest, BigWrite) { diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc index e09afafe9..c5acfc794 100644 --- a/test/syscalls/linux/exec.cc +++ b/test/syscalls/linux/exec.cc @@ -553,7 +553,12 @@ TEST(ExecTest, SymlinkLimitRefreshedForInterpreter) { // Hold onto TempPath objects so they are not destructed prematurely. std::vector<TempPath> interpreter_symlinks; std::vector<TempPath> script_symlinks; - for (int i = 0; i < kLinuxMaxSymlinks; i++) { + // Replace both the interpreter and script paths with symlink chains of just + // over half the symlink limit each; this is the minimum required to test that + // the symlink limit applies separately to each traversal, while tolerating + // some symlinks in the resolution of (the original) interpreter_path and + // script_path. + for (int i = 0; i < (kLinuxMaxSymlinks / 2) + 1; i++) { interpreter_symlinks.push_back(ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(tmp_dir, interpreter_path))); interpreter_path = interpreter_symlinks[i].path(); @@ -679,18 +684,16 @@ TEST(ExecveatTest, UnshareFiles) { const FileDescriptor fd_closed_on_exec = ASSERT_NO_ERRNO_AND_VALUE(Open(tempFile.path(), O_RDONLY | O_CLOEXEC)); - pid_t child; - EXPECT_THAT(child = syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, - 0, 0, 0, 0), - SyscallSucceeds()); + ExecveArray argv = {"test"}; + ExecveArray envp; + std::string child_path = RunfilePath(kBasicWorkload); + pid_t child = + syscall(__NR_clone, SIGCHLD | CLONE_VFORK | CLONE_FILES, 0, 0, 0, 0); if (child == 0) { - ExecveArray argv = {"test"}; - ExecveArray envp; - ASSERT_THAT( - execve(RunfilePath(kBasicWorkload).c_str(), argv.get(), envp.get()), - SyscallSucceeds()); + execve(child_path.c_str(), argv.get(), envp.get()); _exit(1); } + ASSERT_THAT(child, SyscallSucceeds()); int status; ASSERT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds()); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index 7819f4ac3..cabc2b751 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -15,16 +15,27 @@ #include <errno.h> #include <fcntl.h> #include <signal.h> +#include <sys/eventfd.h> #include <sys/resource.h> +#include <sys/signalfd.h> +#include <sys/socket.h> #include <sys/stat.h> +#include <sys/timerfd.h> #include <syscall.h> #include <time.h> #include <unistd.h> +#include <ctime> + #include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" #include "test/syscalls/linux/file_base.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/cleanup.h" +#include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -70,6 +81,12 @@ TEST_F(AllocateTest, Fallocate) { ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds()); ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); EXPECT_EQ(buf.st_size, 40); + + // Given length 0 should fail with EINVAL. + ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 50, 0), + SyscallFailsWithErrno(EINVAL)); + ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds()); + EXPECT_EQ(buf.st_size, 40); } TEST_F(AllocateTest, FallocateInvalid) { @@ -136,6 +153,34 @@ TEST_F(AllocateTest, FallocateRlimit) { ASSERT_THAT(sigprocmask(SIG_UNBLOCK, &new_mask, nullptr), SyscallSucceeds()); } +TEST_F(AllocateTest, FallocateOtherFDs) { + int fd; + ASSERT_THAT(fd = timerfd_create(CLOCK_MONOTONIC, 0), SyscallSucceeds()); + auto timer_fd = FileDescriptor(fd); + EXPECT_THAT(fallocate(timer_fd.get(), 0, 0, 10), + SyscallFailsWithErrno(ENODEV)); + + sigset_t mask; + sigemptyset(&mask); + ASSERT_THAT(fd = signalfd(-1, &mask, 0), SyscallSucceeds()); + auto sfd = FileDescriptor(fd); + EXPECT_THAT(fallocate(sfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + auto efd = + ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, EFD_NONBLOCK | EFD_SEMAPHORE)); + EXPECT_THAT(fallocate(efd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + auto sockfd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + EXPECT_THAT(fallocate(sockfd.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + int socks[2]; + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), + SyscallSucceeds()); + auto sock0 = FileDescriptor(socks[0]); + auto sock1 = FileDescriptor(socks[1]); + EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc index 25bef2522..34016d4bd 100644 --- a/test/syscalls/linux/fcntl.cc +++ b/test/syscalls/linux/fcntl.cc @@ -18,7 +18,10 @@ #include <syscall.h> #include <unistd.h> +#include <iostream> +#include <list> #include <string> +#include <vector> #include "gtest/gtest.h" #include "absl/base/macros.h" @@ -37,6 +40,7 @@ #include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" +#include "test/util/thread_util.h" #include "test/util/timer_util.h" ABSL_FLAG(std::string, child_setlock_on, "", @@ -953,15 +957,18 @@ TEST(FcntlTest, DupAfterO_ASYNC) { EXPECT_EQ(after & O_ASYNC, O_ASYNC); } -TEST(FcntlTest, GetOwn) { +TEST(FcntlTest, GetOwnNone) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), 0); + // Use the raw syscall because the glibc wrapper may convert F_{GET,SET}OWN + // into F_{GET,SET}OWN_EX. + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); MaybeSave(); } -TEST(FcntlTest, GetOwnEx) { +TEST(FcntlTest, GetOwnExNone) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); @@ -970,6 +977,100 @@ TEST(FcntlTest, GetOwnEx) { SyscallSucceedsWithValue(0)); } +TEST(FcntlTest, SetOwnInvalidPid) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 12345678), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnInvalidPgrp) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -12345678), + SyscallFailsWithErrno(ESRCH)); +} + +TEST(FcntlTest, SetOwnPid) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(pid)); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnPgrp) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pgid; + EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); + + // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the + // negative return value as an error, converting the return value to -1 and + // setting errno accordingly. + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, F_OWNER_PGRP); + EXPECT_EQ(got_owner.pid, pgid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnUnset) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + // Set and unset pid. + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + + // Set and unset pgid. + pid_t pgid; + EXPECT_THAT(pgid = getpgrp(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, -pgid), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, 0), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + MaybeSave(); +} + +// F_SETOWN flips the sign of negative values, an operation that is guarded +// against overflow. +TEST(FcntlTest, SetOwnOverflow) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, INT_MIN), + SyscallFailsWithErrno(EINVAL)); +} + TEST(FcntlTest, SetOwnExInvalidType) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); @@ -1025,9 +1126,10 @@ TEST(FcntlTest, SetOwnExTid) { EXPECT_THAT(owner.pid = syscall(__NR_gettid), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(owner.pid)); MaybeSave(); } @@ -1040,9 +1142,10 @@ TEST(FcntlTest, SetOwnExPid) { EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), owner.pid); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(owner.pid)); MaybeSave(); } @@ -1050,18 +1153,54 @@ TEST(FcntlTest, SetOwnExPgrp) { FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + f_owner_ex set_owner = {}; + set_owner.type = F_OWNER_PGRP; + EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); + + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), + SyscallSucceedsWithValue(0)); + + // Verify with F_GETOWN_EX; using F_GETOWN on Linux may incorrectly treat the + // negative return value as an error, converting the return value to -1 and + // setting errno accordingly. + f_owner_ex got_owner = {}; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(got_owner.type, set_owner.type); + EXPECT_EQ(got_owner.pid, set_owner.pid); + MaybeSave(); +} + +TEST(FcntlTest, SetOwnExUnset) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + // Set and unset pid. f_owner_ex owner = {}; + owner.type = F_OWNER_PID; + EXPECT_THAT(owner.pid = getpid(), SyscallSucceeds()); + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + owner.pid = 0; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); + + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); + + // Set and unset pgid. owner.type = F_OWNER_PGRP; EXPECT_THAT(owner.pid = getpgrp(), SyscallSucceeds()); - ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), - SyscallSucceeds()); - - // NOTE(igudger): I don't understand why, but this is flaky on Linux. - // GetOwnExPgrp (below) does not have this issue. - SKIP_IF(!IsRunningOnGvisor()); + SyscallSucceedsWithValue(0)); + owner.pid = 0; + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &owner), + SyscallSucceedsWithValue(0)); - EXPECT_EQ(syscall(__NR_fcntl, s.get(), F_GETOWN), -owner.pid); + EXPECT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN), + SyscallSucceedsWithValue(0)); MaybeSave(); } @@ -1074,7 +1213,7 @@ TEST(FcntlTest, GetOwnExTid) { EXPECT_THAT(set_owner.pid = syscall(__NR_gettid), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), @@ -1092,7 +1231,7 @@ TEST(FcntlTest, GetOwnExPid) { EXPECT_THAT(set_owner.pid = getpid(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), @@ -1110,7 +1249,7 @@ TEST(FcntlTest, GetOwnExPgrp) { EXPECT_THAT(set_owner.pid = getpgrp(), SyscallSucceeds()); ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN_EX, &set_owner), - SyscallSucceeds()); + SyscallSucceedsWithValue(0)); f_owner_ex got_owner = {}; ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_GETOWN_EX, &got_owner), @@ -1119,6 +1258,45 @@ TEST(FcntlTest, GetOwnExPgrp) { EXPECT_EQ(got_owner.pid, set_owner.pid); } +// Make sure that making multiple concurrent changes to async signal generation +// does not cause any race issues. +TEST(FcntlTest, SetFlSetOwnDoNotRace) { + FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE( + Socket(AF_UNIX, SOCK_SEQPACKET | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); + + pid_t pid; + EXPECT_THAT(pid = getpid(), SyscallSucceeds()); + + constexpr absl::Duration runtime = absl::Milliseconds(300); + auto setAsync = [&s, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, O_ASYNC), + SyscallSucceeds()); + sched_yield(); + } + }; + auto resetAsync = [&s, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETFL, 0), SyscallSucceeds()); + sched_yield(); + } + }; + auto setOwn = [&s, &pid, &runtime] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(syscall(__NR_fcntl, s.get(), F_SETOWN, pid), + SyscallSucceeds()); + sched_yield(); + } + }; + + std::list<ScopedThread> threads; + for (int i = 0; i < 10; i++) { + threads.emplace_back(setAsync); + threads.emplace_back(resetAsync); + threads.emplace_back(setOwn); + } +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc index 40c80a6e1..90b1f0508 100644 --- a/test/syscalls/linux/futex.cc +++ b/test/syscalls/linux/futex.cc @@ -18,6 +18,7 @@ #include <sys/syscall.h> #include <sys/time.h> #include <sys/types.h> +#include <syscall.h> #include <unistd.h> #include <algorithm> @@ -737,6 +738,97 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) { } } +int get_robust_list(int pid, struct robust_list_head** head_ptr, + size_t* len_ptr) { + return syscall(__NR_get_robust_list, pid, head_ptr, len_ptr); +} + +int set_robust_list(struct robust_list_head* head, size_t len) { + return syscall(__NR_set_robust_list, head, len); +} + +TEST(RobustFutexTest, BasicSetGet) { + struct robust_list_head hd = {}; + struct robust_list_head* hd_ptr = &hd; + + // Set! + EXPECT_THAT(set_robust_list(hd_ptr, sizeof(hd)), SyscallSucceedsWithValue(0)); + + // Get! + struct robust_list_head* new_hd_ptr = hd_ptr; + size_t len; + EXPECT_THAT(get_robust_list(0, &new_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(new_hd_ptr, hd_ptr); + EXPECT_EQ(len, sizeof(hd)); +} + +TEST(RobustFutexTest, GetFromOtherTid) { + // Get the current tid and list head. + pid_t tid = gettid(); + struct robust_list_head* hd_ptr = {}; + size_t len; + EXPECT_THAT(get_robust_list(0, &hd_ptr, &len), SyscallSucceedsWithValue(0)); + + // Create a new thread. + ScopedThread t([&] { + // Current tid list head should be different from parent tid. + struct robust_list_head* got_hd_ptr = {}; + EXPECT_THAT(get_robust_list(0, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_NE(hd_ptr, got_hd_ptr); + + // Get the parent list head by passing its tid. + EXPECT_THAT(get_robust_list(tid, &got_hd_ptr, &len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(hd_ptr, got_hd_ptr); + }); + + // Wait for thread. + t.Join(); +} + +TEST(RobustFutexTest, InvalidSize) { + struct robust_list_head* hd = {}; + EXPECT_THAT(set_robust_list(hd, sizeof(*hd) + 1), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(RobustFutexTest, PthreadMutexAttr) { + constexpr int kNumMutexes = 3; + + // Create a bunch of robust mutexes. + pthread_mutexattr_t attrs[kNumMutexes]; + pthread_mutex_t mtxs[kNumMutexes]; + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutexattr_init(&attrs[i]) == 0); + TEST_PCHECK(pthread_mutexattr_setrobust(&attrs[i], PTHREAD_MUTEX_ROBUST) == + 0); + TEST_PCHECK(pthread_mutex_init(&mtxs[i], &attrs[i]) == 0); + } + + // Start thread to lock the mutexes and then exit. + ScopedThread t([&] { + for (int i = 0; i < kNumMutexes; i++) { + TEST_PCHECK(pthread_mutex_lock(&mtxs[i]) == 0); + } + pthread_exit(NULL); + }); + + // Wait for thread. + t.Join(); + + // Now try to take the mutexes. + for (int i = 0; i < kNumMutexes; i++) { + // Should get EOWNERDEAD. + EXPECT_EQ(pthread_mutex_lock(&mtxs[i]), EOWNERDEAD); + // Make the mutex consistent. + EXPECT_EQ(pthread_mutex_consistent(&mtxs[i]), 0); + // Unlock. + EXPECT_EQ(pthread_mutex_unlock(&mtxs[i]), 0); + } +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index b147d6181..b040cdcf7 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -32,6 +32,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "test/util/eventfd_util.h" @@ -393,7 +394,7 @@ TYPED_TEST(GetdentsTest, ProcSelfFd) { // Make the buffer very small since we want to iterate. typename TestFixture::DirentBufferType dirents( 2 * sizeof(typename TestFixture::LinuxDirentType)); - std::unordered_set<int> prev_fds; + absl::node_hash_set<int> prev_fds; while (true) { dirents.Reset(); int rv; diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 1d1a7171d..220874aeb 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -29,6 +29,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/epoll_util.h" @@ -332,9 +333,27 @@ PosixErrorOr<int> InotifyAddWatch(int fd, const std::string& path, return wd; } -TEST(Inotify, InotifyFdNotWritable) { +TEST(Inotify, IllegalSeek) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); - EXPECT_THAT(write(fd.get(), "x", 1), SyscallFailsWithErrno(EBADF)); + EXPECT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalPread) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + int val; + EXPECT_THAT(pread(fd.get(), &val, sizeof(val), 0), + SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalPwrite) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + EXPECT_THAT(pwrite(fd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); +} + +TEST(Inotify, IllegalWrite) { + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(0)); + int val = 0; + EXPECT_THAT(write(fd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EBADF)); } TEST(Inotify, InitFlags) { @@ -1305,7 +1324,7 @@ TEST(Inotify, UtimesGeneratesAttribEvent) { const int wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - struct timeval times[2] = {{1, 0}, {2, 0}}; + const struct timeval times[2] = {{1, 0}, {2, 0}}; EXPECT_THAT(futimes(file1_fd.get(), times), SyscallSucceeds()); const std::vector<Event> events = @@ -1484,20 +1503,26 @@ TEST(Inotify, DuplicateWatchReturnsSameWatchDescriptor) { TEST(Inotify, UnmatchedEventsAreDiscarded) { const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - const TempPath file1 = + TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file1.path(), IN_ACCESS)); - const FileDescriptor file1_fd = + FileDescriptor file1_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); - const std::vector<Event> events = - ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); // We only asked for access events, the open event should be discarded. ASSERT_THAT(events, Are({})); + + // IN_IGNORED events are always generated, regardless of the mask. + file1_fd.reset(); + file1.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); + ASSERT_THAT(events, Are({Event(IN_IGNORED, wd)})); } TEST(Inotify, AddWatchWithInvalidEventMaskFails) { @@ -1839,9 +1864,7 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateFileWith(dir.path(), "123", TempPath::kDefaultFileMode)); - - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); @@ -1854,7 +1877,7 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { int val = 0; ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); - const std::vector<Event> events = + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); EXPECT_THAT(events, Are({ Event(IN_ATTRIB, file_wd), @@ -1864,6 +1887,15 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { Event(IN_MODIFY, dir_wd, Basename(file.path())), Event(IN_MODIFY, file_wd), })); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_CLOSE_WRITE, dir_wd, Basename(file.path())), + Event(IN_CLOSE_WRITE, file_wd), + Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + })); } // Watches created with IN_EXCL_UNLINK will stop emitting events on fds for @@ -1880,13 +1912,14 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); // Unlink the child, which should cause further operations on the open file // descriptor to be ignored. @@ -1894,14 +1927,28 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { int val = 0; ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); - const std::vector<Event> events = + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({Event(IN_DELETE, wd, Basename(file.path()))})); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + })); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({ + Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + })); } // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { + // TODO(gvisor.dev/issue/1624): This test fails on VFS1. Remove once VFS1 is + // deleted. + SKIP_IF(IsRunningWithVFS1()); + const DisableSave ds; const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -1911,20 +1958,29 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const FileDescriptor fd = + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(dirPath.c_str(), O_RDONLY | O_DIRECTORY)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + const int parent_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( inotify_fd.get(), parent.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int self_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); // Unlink the dir, and then close the open fd. ASSERT_THAT(rmdir(dirPath.c_str()), SyscallSucceeds()); dir.reset(); - const std::vector<Event> events = + std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); // No close event should appear. ASSERT_THAT(events, - Are({Event(IN_DELETE | IN_ISDIR, wd, Basename(dirPath))})); + Are({Event(IN_DELETE | IN_ISDIR, parent_wd, Basename(dirPath))})); + + fd.reset(); + events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + ASSERT_THAT(events, Are({ + Event(IN_DELETE_SELF, self_wd), + Event(IN_IGNORED, self_wd), + })); } // If "dir/child" and "dir/child2" are links to the same file, and "dir/child" @@ -1988,10 +2044,6 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR)); - const FileDescriptor inotify_fd = - ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); - const int wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( - inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); // NOTE(b/157163751): Create another link before unlinking. This is needed for // the gofer filesystem in gVisor, where open fds will not work once the link @@ -2005,6 +2057,13 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { ASSERT_THAT(rc, SyscallSucceeds()); } + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), dir.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( + inotify_fd.get(), file.path(), IN_ALL_EVENTS | IN_EXCL_UNLINK)); + // Even after unlinking, inode-level operations will trigger events regardless // of IN_EXCL_UNLINK. ASSERT_THAT(unlink(file.path().c_str()), SyscallSucceeds()); @@ -2014,20 +2073,306 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); EXPECT_THAT(events, Are({ - Event(IN_DELETE, wd, Basename(file.path())), - Event(IN_MODIFY, wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), })); - struct timeval times[2] = {{1, 0}, {2, 0}}; + const struct timeval times[2] = {{1, 0}, {2, 0}}; ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds()); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd, Basename(file.path()))})); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + })); // S/R is disabled on this entire test due to behavior with unlink; it must // also be disabled after this point because of fchmod. ASSERT_THAT(fchmod(fd.get(), 0777), SyscallSucceeds()); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({Event(IN_ATTRIB, wd, Basename(file.path()))})); + EXPECT_THAT(events, Are({ + Event(IN_ATTRIB, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + })); +} + +TEST(Inotify, OneShot) { + // TODO(gvisor.dev/issue/1624): IN_ONESHOT not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_MODIFY | IN_ONESHOT)); + + // Open an fd, write to it, and then close it. + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + ASSERT_THAT(write(fd.get(), "x", 1), SyscallSucceedsWithValue(1)); + fd.reset(); + + // We should get a single event followed by IN_IGNORED indicating removal + // of the one-shot watch. Prior activity (i.e. open) that is not in the mask + // should not trigger removal, and activity after removal (i.e. close) should + // not generate events. + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({ + Event(IN_MODIFY, wd), + Event(IN_IGNORED, wd), + })); + + // The watch should already have been removed. + EXPECT_THAT(inotify_rm_watch(inotify_fd.get(), wd), + SyscallFailsWithErrno(EINVAL)); +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when inotify instances and watch targets are concurrently being +// destroyed. +TEST(InotifyTest, InotifyAndTargetDestructionDoNotDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + // A file descriptor protected by a mutex. This ensures that while a + // descriptor is in use, it cannot be closed and reused for a different file + // description. + struct atomic_fd { + int fd; + absl::Mutex mu; + }; + + // Set up initial inotify instances. + constexpr int num_fds = 3; + std::vector<atomic_fd> fds(num_fds); + for (int i = 0; i < num_fds; i++) { + int fd; + ASSERT_THAT(fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + fds[i].fd = fd; + } + + // Set up initial watch targets. + std::vector<std::string> paths; + for (int i = 0; i < 3; i++) { + paths.push_back(NewTempAbsPath()); + ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + + constexpr absl::Duration runtime = absl::Seconds(4); + + // Constantly replace each inotify instance with a new one. + auto replace_fds = [&] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& afd : fds) { + int new_fd; + ASSERT_THAT(new_fd = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + absl::MutexLock l(&afd.mu); + ASSERT_THAT(close(afd.fd), SyscallSucceeds()); + afd.fd = new_fd; + for (auto& p : paths) { + // inotify_add_watch may fail if the file at p was deleted. + ASSERT_THAT(inotify_add_watch(afd.fd, p.c_str(), IN_ALL_EVENTS), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + } + } + sched_yield(); + } + }; + + std::list<ScopedThread> ts; + for (int i = 0; i < 3; i++) { + ts.emplace_back(replace_fds); + } + + // Constantly replace each watch target with a new one. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& p : paths) { + ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds()); + ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + sched_yield(); + } +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when adding/removing watches occurs concurrently with the +// removal of their targets. +TEST(InotifyTest, AddRemoveUnlinkDoNotDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + // Set up inotify instances. + constexpr int num_fds = 3; + std::vector<int> fds(num_fds); + for (int i = 0; i < num_fds; i++) { + ASSERT_THAT(fds[i] = inotify_init1(IN_NONBLOCK), SyscallSucceeds()); + } + + // Set up initial watch targets. + std::vector<std::string> paths; + for (int i = 0; i < 3; i++) { + paths.push_back(NewTempAbsPath()); + ASSERT_THAT(mknod(paths[i].c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + + constexpr absl::Duration runtime = absl::Seconds(1); + + // Constantly add/remove watches for each inotify instance/watch target pair. + auto add_remove_watches = [&] { + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (int fd : fds) { + for (auto& p : paths) { + // Do not assert on inotify_add_watch and inotify_rm_watch. They may + // fail if the file at p was deleted. inotify_add_watch may also fail + // if another thread beat us to adding a watch. + const int wd = inotify_add_watch(fd, p.c_str(), IN_ALL_EVENTS); + if (wd > 0) { + inotify_rm_watch(fd, wd); + } + } + } + sched_yield(); + } + }; + + std::list<ScopedThread> ts; + for (int i = 0; i < 15; i++) { + ts.emplace_back(add_remove_watches); + } + + // Constantly replace each watch target with a new one. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + for (auto& p : paths) { + ASSERT_THAT(unlink(p.c_str()), SyscallSucceeds()); + ASSERT_THAT(mknod(p.c_str(), S_IFREG | 0600, 0), SyscallSucceeds()); + } + sched_yield(); + } +} + +// This test helps verify that the lock order of filesystem and inotify locks +// is respected when many inotify events and filesystem operations occur +// simultaneously. +TEST(InotifyTest, NotifyNoDeadlock_NoRandomSave) { + const DisableSave ds; // Too many syscalls. + + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const std::string dir = parent.path(); + + // mu protects file, which will change on rename. + absl::Mutex mu; + std::string file = NewTempAbsPathInDir(dir); + ASSERT_THAT(mknod(file.c_str(), 0644 | S_IFREG, 0), SyscallSucceeds()); + + const absl::Duration runtime = absl::Milliseconds(300); + + // Add/remove watches on dir and file. + ScopedThread add_remove_watches([&] { + const FileDescriptor ifd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + int dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS)); + int file_wd; + { + absl::ReaderMutexLock l(&mu); + file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS)); + } + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + ASSERT_THAT(inotify_rm_watch(ifd.get(), file_wd), SyscallSucceeds()); + ASSERT_THAT(inotify_rm_watch(ifd.get(), dir_wd), SyscallSucceeds()); + dir_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), dir, IN_ALL_EVENTS)); + { + absl::ReaderMutexLock l(&mu); + file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(ifd.get(), file, IN_ALL_EVENTS)); + } + sched_yield(); + } + }); + + // Modify attributes on dir and file. + ScopedThread stats([&] { + int fd, dir_fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds()); + } + ASSERT_THAT(dir_fd = open(dir.c_str(), O_RDONLY | O_DIRECTORY), + SyscallSucceeds()); + const struct timeval times[2] = {{1, 0}, {2, 0}}; + + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + { + absl::ReaderMutexLock l(&mu); + EXPECT_THAT(utimes(file.c_str(), times), SyscallSucceeds()); + } + EXPECT_THAT(futimes(fd, times), SyscallSucceeds()); + EXPECT_THAT(utimes(dir.c_str(), times), SyscallSucceeds()); + EXPECT_THAT(futimes(dir_fd, times), SyscallSucceeds()); + sched_yield(); + } + }); + + // Modify extended attributes on dir and file. + ScopedThread xattrs([&] { + // TODO(gvisor.dev/issue/1636): Support extended attributes in runsc gofer. + if (!IsRunningOnGvisor()) { + int fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDONLY), SyscallSucceeds()); + } + + const char* name = "user.test"; + int val = 123; + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT( + setxattr(file.c_str(), name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + ASSERT_THAT(removexattr(file.c_str(), name), SyscallSucceeds()); + } + + ASSERT_THAT(fsetxattr(fd, name, &val, sizeof(val), /*flags=*/0), + SyscallSucceeds()); + ASSERT_THAT(fremovexattr(fd, name), SyscallSucceeds()); + sched_yield(); + } + } + }); + + // Read and write file's contents. Read and write dir's entries. + ScopedThread read_write([&] { + int fd; + { + absl::ReaderMutexLock l(&mu); + ASSERT_THAT(fd = open(file.c_str(), O_RDWR), SyscallSucceeds()); + } + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + int val = 123; + ASSERT_THAT(write(fd, &val, sizeof(val)), SyscallSucceeds()); + ASSERT_THAT(read(fd, &val, sizeof(val)), SyscallSucceeds()); + TempPath new_file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir)); + ASSERT_NO_ERRNO(ListDir(dir, false)); + new_file.reset(); + sched_yield(); + } + }); + + // Rename file. + for (auto start = absl::Now(); absl::Now() - start < runtime;) { + const std::string new_path = NewTempAbsPathInDir(dir); + { + absl::WriterMutexLock l(&mu); + ASSERT_THAT(rename(file.c_str(), new_path.c_str()), SyscallSucceeds()); + file = new_path; + } + sched_yield(); + } } } // namespace diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc index e74fa2ed5..544681168 100644 --- a/test/syscalls/linux/link.cc +++ b/test/syscalls/linux/link.cc @@ -79,8 +79,13 @@ TEST(LinkTest, PermissionDenied) { // Make the file "unsafe" to link by making it only readable, but not // writable. - const auto oldfile = + const auto unwriteable_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); + const std::string special_path = NewTempAbsPath(); + ASSERT_THAT(mkfifo(special_path.c_str(), 0666), SyscallSucceeds()); + const auto setuid_file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0666 | S_ISUID)); + const std::string newname = NewTempAbsPath(); // Do setuid in a separate thread so that after finishing this test, the @@ -97,8 +102,14 @@ TEST(LinkTest, PermissionDenied) { EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)), SyscallSucceeds()); - EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()), + EXPECT_THAT(link(unwriteable_file.path().c_str(), newname.c_str()), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(link(special_path.c_str(), newname.c_str()), SyscallFailsWithErrno(EPERM)); + if (!IsRunningWithVFS1()) { + EXPECT_THAT(link(setuid_file.path().c_str(), newname.c_str()), + SyscallFailsWithErrno(EPERM)); + } }); } diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index 4c45766c7..05dfb375a 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -15,6 +15,7 @@ #include <errno.h> #include <fcntl.h> #include <sys/stat.h> +#include <sys/types.h> #include <sys/un.h> #include <unistd.h> @@ -39,7 +40,28 @@ TEST(MknodTest, RegularFile) { EXPECT_THAT(mknod(node1.c_str(), 0, 0), SyscallSucceeds()); } -TEST(MknodTest, MknodAtRegularFile) { +TEST(MknodTest, RegularFilePermissions) { + const std::string node = NewTempAbsPath(); + mode_t newUmask = 0077; + umask(newUmask); + + // Attempt to open file with mode 0777. Not specifying file type should create + // a regualar file. + mode_t perms = S_IRWXU | S_IRWXG | S_IRWXO; + EXPECT_THAT(mknod(node.c_str(), perms, 0), SyscallSucceeds()); + + // In the absence of a default ACL, the permissions of the created node are + // (mode & ~umask). -- mknod(2) + mode_t wantPerms = perms & ~newUmask; + struct stat st; + ASSERT_THAT(stat(node.c_str(), &st), SyscallSucceeds()); + ASSERT_EQ(st.st_mode & 0777, wantPerms); + + // "Zero file type is equivalent to type S_IFREG." - mknod(2) + ASSERT_EQ(st.st_mode & S_IFMT, S_IFREG); +} + +TEST(MknodTest, MknodAtFIFO) { const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const std::string fifo_relpath = NewTempRelPath(); const std::string fifo = JoinPath(dir.path(), fifo_relpath); @@ -72,7 +94,7 @@ TEST(MknodTest, MknodOnExistingPathFails) { TEST(MknodTest, UnimplementedTypesReturnError) { const std::string path = NewTempAbsPath(); - if (IsRunningOnGvisor()) { + if (IsRunningWithVFS1()) { ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0), SyscallFailsWithErrno(EOPNOTSUPP)); } diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index a3e9745cf..97e8d0f7e 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -321,6 +321,28 @@ TEST(MountTest, RenameRemoveMountPoint) { ASSERT_THAT(rmdir(dir.path().c_str()), SyscallFailsWithErrno(EBUSY)); } +TEST(MountTest, MountFuseFilesystemNoDevice) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(MountTest, MountFuseFilesystem) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); + std::string mopts = "fd=" + std::to_string(fd.get()); + + auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto const mount = + ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index 670c0284b..bf350946b 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -235,7 +235,7 @@ TEST_F(OpenTest, AppendOnly) { ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR | O_APPEND)); EXPECT_THAT(lseek(fd2.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); - // Then try to write to the first file and make sure the bytes are appended. + // Then try to write to the first fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd1.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -247,7 +247,7 @@ TEST_F(OpenTest, AppendOnly) { EXPECT_THAT(lseek(fd1.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(kBufSize * 2)); - // Then try to write to the second file and make sure the bytes are appended. + // Then try to write to the second fd and make sure the bytes are appended. EXPECT_THAT(WriteFd(fd2.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(buf.size())); @@ -439,6 +439,12 @@ TEST_F(OpenTest, CanTruncateWithStrangePermissions) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(OpenTest, OpenNonDirectoryWithTrailingSlash) { + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const std::string bad_path = file.path() + "/"; + EXPECT_THAT(open(bad_path.c_str(), O_RDONLY), SyscallFailsWithErrno(ENOTDIR)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index 5ac68feb4..861617ff7 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -188,11 +188,12 @@ void ReceiveMessage(int sock, int ifindex) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, ifindex); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -233,7 +234,7 @@ TEST_P(CookedPacketTest, Receive) { // Send via a packet socket. TEST_P(CookedPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -343,7 +344,7 @@ TEST_P(CookedPacketTest, BindReceive) { } // Double Bind socket. -TEST_P(CookedPacketTest, DoubleBind) { +TEST_P(CookedPacketTest, DoubleBindSucceeds) { struct sockaddr_ll bind_addr = {}; bind_addr.sll_family = AF_PACKET; bind_addr.sll_protocol = htons(GetParam()); @@ -354,12 +355,11 @@ TEST_P(CookedPacketTest, DoubleBind) { SyscallSucceeds()); // Binding socket again should fail. - ASSERT_THAT( - bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - // Linux 4.09 returns EINVAL here, but some time before 4.19 it switched - // to EADDRINUSE. - AnyOf(SyscallFailsWithErrno(EADDRINUSE), SyscallFailsWithErrno(EINVAL))); + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + // Linux 4.09 returns EINVAL here, but some time before 4.19 it + // switched to EADDRINUSE. + SyscallSucceeds()); } // Bind and verify we do not receive data on interface which is not bound @@ -417,6 +417,122 @@ TEST_P(CookedPacketTest, BindDrop) { EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); } +// Verify that we receive outbound packets. This test requires at least one +// non loopback interface so that we can actually capture an outgoing packet. +TEST_P(CookedPacketTest, ReceiveOutbound) { + // Only ETH_P_ALL sockets can receive outbound packets on linux. + SKIP_IF(GetParam() != ETH_P_ALL); + + // Let's use a simple IP payload: a UDP datagram. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ifaddrs* if_addr_list = nullptr; + auto cleanup = Cleanup([&if_addr_list]() { freeifaddrs(if_addr_list); }); + + ASSERT_THAT(getifaddrs(&if_addr_list), SyscallSucceeds()); + + // Get interface other than loopback. + struct ifreq ifr = {}; + for (struct ifaddrs* i = if_addr_list; i; i = i->ifa_next) { + if (strcmp(i->ifa_name, "lo") != 0) { + strncpy(ifr.ifr_name, i->ifa_name, sizeof(ifr.ifr_name)); + break; + } + } + + // Skip if no interface is available other than loopback. + if (strlen(ifr.ifr_name) == 0) { + GTEST_SKIP(); + } + + // Get interface index and name. + EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_NE(ifr.ifr_ifindex, 0); + int ifindex = ifr.ifr_ifindex; + + constexpr int kMACSize = 6; + char hwaddr[kMACSize]; + // Get interface address. + ASSERT_THAT(ioctl(socket_, SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + ASSERT_THAT(ifr.ifr_hwaddr.sa_family, + AnyOf(Eq(ARPHRD_NONE), Eq(ARPHRD_ETHER))); + memcpy(hwaddr, ifr.ifr_hwaddr.sa_data, kMACSize); + + // Just send it to the google dns server 8.8.8.8. It's UDP we don't care + // if it actually gets to the DNS Server we just want to see that we receive + // it on our AF_PACKET socket. + // + // NOTE: We just want to pick an IP that is non-local to avoid having to + // handle ARP as this should cause the UDP packet to be sent to the default + // gateway configured for the system under test. Otherwise the only packet we + // will see is the ARP query unless we picked an IP which will actually + // resolve. The test is a bit brittle but this was the best compromise for + // now. + struct sockaddr_in dest = {}; + ASSERT_EQ(inet_pton(AF_INET, "8.8.8.8", &dest.sin_addr.s_addr), 1); + dest.sin_family = AF_INET; + dest.sin_port = kPort; + EXPECT_THAT(sendto(udp_sock.get(), kMessage, sizeof(kMessage), 0, + reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), + SyscallSucceedsWithValue(sizeof(kMessage))); + + // Wait and make sure the socket receives the data. + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(1)); + + // Now read and check that the packet is the one we just sent. + // Read and verify the data. + constexpr size_t packet_size = + sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kMessage); + char buf[64]; + struct sockaddr_ll src = {}; + socklen_t src_len = sizeof(src); + ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_len), + SyscallSucceedsWithValue(packet_size)); + + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but since commit b2cf86e1563e33a14a1c69b3e508d15dc12f804c returns + // sizeof(sockaddr_ll). + ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); + + // Verify the source address. + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, ifindex); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + EXPECT_EQ(src.sll_pkttype, PACKET_OUTGOING); + // Verify the link address of the interface matches that of the non + // non loopback interface address we stored above. + for (int i = 0; i < src.sll_halen; i++) { + EXPECT_EQ(src.sll_addr[i], hwaddr[i]); + } + + // Verify the IP header. + struct iphdr ip = {}; + memcpy(&ip, buf, sizeof(ip)); + EXPECT_EQ(ip.ihl, 5); + EXPECT_EQ(ip.version, 4); + EXPECT_EQ(ip.tot_len, htons(packet_size)); + EXPECT_EQ(ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ip.daddr, dest.sin_addr.s_addr); + EXPECT_NE(ip.saddr, htonl(INADDR_LOOPBACK)); + + // Verify the UDP header. + struct udphdr udp = {}; + memcpy(&udp, buf + sizeof(iphdr), sizeof(udp)); + EXPECT_EQ(udp.dest, kPort); + EXPECT_EQ(udp.len, htons(sizeof(udphdr) + sizeof(kMessage))); + + // Verify the payload. + char* payload = reinterpret_cast<char*>(buf + sizeof(iphdr) + sizeof(udphdr)); + EXPECT_EQ(strncmp(payload, kMessage, sizeof(kMessage)), 0); +} + // Bind with invalid address. TEST_P(CookedPacketTest, BindFail) { // Null address. diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index d258d353c..a11a03415 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,6 +14,9 @@ #include <arpa/inet.h> #include <linux/capability.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <linux/if_arp.h> #include <linux/if_packet.h> #include <net/ethernet.h> @@ -97,7 +100,7 @@ class RawPacketTest : public ::testing::TestWithParam<int> { int GetLoopbackIndex(); // The socket used for both reading and writing. - int socket_; + int s_; }; void RawPacketTest::SetUp() { @@ -108,34 +111,58 @@ void RawPacketTest::SetUp() { } if (!IsRunningOnGvisor()) { + // Ensure that looped back packets aren't rejected by the kernel. FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY)); + Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDWR)); FileDescriptor routeLocalnet = ASSERT_NO_ERRNO_AND_VALUE( - Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDONLY)); + Open("/proc/sys/net/ipv4/conf/lo/route_localnet", O_RDWR)); char enabled; ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); + if (enabled != '1') { + enabled = '1'; + ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(write(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(lseek(acceptLocal.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(read(acceptLocal.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), SyscallSucceedsWithValue(1)); - ASSERT_EQ(enabled, '1'); + if (enabled != '1') { + enabled = '1'; + ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(write(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_THAT(lseek(routeLocalnet.get(), 0, SEEK_SET), + SyscallSucceedsWithValue(0)); + ASSERT_THAT(read(routeLocalnet.get(), &enabled, 1), + SyscallSucceedsWithValue(1)); + ASSERT_EQ(enabled, '1'); + } } - ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), + ASSERT_THAT(s_ = socket(AF_PACKET, SOCK_RAW, htons(GetParam())), SyscallSucceeds()); } void RawPacketTest::TearDown() { // TearDown will be run even if we skip the test. if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - EXPECT_THAT(close(socket_), SyscallSucceeds()); + EXPECT_THAT(close(s_), SyscallSucceeds()); } } int RawPacketTest::GetLoopbackIndex() { struct ifreq ifr; snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); - EXPECT_THAT(ioctl(socket_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); + EXPECT_THAT(ioctl(s_, SIOCGIFINDEX, &ifr), SyscallSucceeds()); EXPECT_NE(ifr.ifr_ifindex, 0); return ifr.ifr_ifindex; } @@ -149,7 +176,7 @@ TEST_P(RawPacketTest, Receive) { // Wait for the socket to become readable. struct pollfd pfd = {}; - pfd.fd = socket_; + pfd.fd = s_; pfd.events = POLLIN; EXPECT_THAT(RetryEINTR(poll)(&pfd, 1, 2000), SyscallSucceedsWithValue(1)); @@ -159,7 +186,7 @@ TEST_P(RawPacketTest, Receive) { char buf[64]; struct sockaddr_ll src = {}; socklen_t src_len = sizeof(src); - ASSERT_THAT(recvfrom(socket_, buf, sizeof(buf), 0, + ASSERT_THAT(recvfrom(s_, buf, sizeof(buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_len), SyscallSucceedsWithValue(packet_size)); // sockaddr_ll ends with an 8 byte physical address field, but ethernet @@ -168,11 +195,12 @@ TEST_P(RawPacketTest, Receive) { // sizeof(sockaddr_ll). ASSERT_THAT(src_len, AnyOf(Eq(sizeof(src)), Eq(sizeof(src) - 2))); - // TODO(b/129292371): Verify protocol once we return it. + // TODO(gvisor.dev/issue/173): Verify protocol once we return it. // Verify the source address. EXPECT_EQ(src.sll_family, AF_PACKET); EXPECT_EQ(src.sll_ifindex, GetLoopbackIndex()); EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); // This came from the loopback device, so the address is all 0s. for (int i = 0; i < src.sll_halen; i++) { EXPECT_EQ(src.sll_addr[i], 0); @@ -212,7 +240,7 @@ TEST_P(RawPacketTest, Receive) { // Send via a packet socket. TEST_P(RawPacketTest, Send) { - // TODO(b/129292371): Remove once we support packet socket writing. + // TODO(gvisor.dev/issue/173): Remove once we support packet socket writing. SKIP_IF(IsRunningOnGvisor()); // Let's send a UDP packet and receive it using a regular UDP socket. @@ -277,7 +305,7 @@ TEST_P(RawPacketTest, Send) { sizeof(kMessage)); // Send it. - ASSERT_THAT(sendto(socket_, send_buf, sizeof(send_buf), 0, + ASSERT_THAT(sendto(s_, send_buf, sizeof(send_buf), 0, reinterpret_cast<struct sockaddr*>(&dest), sizeof(dest)), SyscallSucceedsWithValue(sizeof(send_buf))); @@ -286,13 +314,13 @@ TEST_P(RawPacketTest, Send) { pfd.fd = udp_sock.get(); pfd.events = POLLIN; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); - pfd.fd = socket_; + pfd.fd = s_; pfd.events = POLLIN; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 5000), SyscallSucceedsWithValue(1)); // Receive on the packet socket. char recv_buf[sizeof(send_buf)]; - ASSERT_THAT(recv(socket_, recv_buf, sizeof(recv_buf), 0), + ASSERT_THAT(recv(s_, recv_buf, sizeof(recv_buf), 0), SyscallSucceedsWithValue(sizeof(recv_buf))); ASSERT_EQ(memcmp(recv_buf, send_buf, sizeof(send_buf)), 0); @@ -309,6 +337,318 @@ TEST_P(RawPacketTest, Send) { EXPECT_EQ(src.sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } +// Check that setting SO_RCVBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketRecvBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum receive buf size by trying to set it to zero. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_RCVBUF above max is clamped to the maximum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketRecvBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover max buf size by trying to set the largest possible buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_RCVBUF min <= kRcvBufSz <= max is honored. +TEST_P(RawPacketTest, SetSocketRecvBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover max buf size by trying to set a really large buffer size. + constexpr int kRcvBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set a zero size receive buffer + // size. + // See: + // https://github.com/torvalds/linux/blob/a5dc8300df75e8b8384b4c82225f1e4a0b4d9b55/net/core/sock.c#L820 + constexpr int kRcvBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove when Netstack matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + ASSERT_EQ(quarter_sz, val); +} + +// Check that setting SO_SNDBUF below min is clamped to the minimum +// receive buffer size. +TEST_P(RawPacketTest, SetSocketSendBufBelowMin) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + + // Linux doubles the value so let's use a value that when doubled will still + // be smaller than min. + int below_min = min / 2 - 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &below_min, sizeof(below_min)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + ASSERT_EQ(min, val); +} + +// Check that setting SO_SNDBUF above max is clamped to the maximum +// send buffer size. +TEST_P(RawPacketTest, SetSocketSendBufAboveMax) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + int max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + + int above_max = max + 1; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &above_max, sizeof(above_max)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(max, val); +} + +// Check that setting SO_SNDBUF min <= kSndBufSz <= max is honored. +TEST_P(RawPacketTest, SetSocketSendBuf) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int max = 0; + int min = 0; + { + // Discover maximum buffer size by trying to set it to a large value. + constexpr int kSndBufSz = 0xffffffff; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + max = 0; + socklen_t max_len = sizeof(max); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &max, &max_len), + SyscallSucceeds()); + } + + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kSndBufSz = 0; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &kSndBufSz, sizeof(kSndBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &min, &min_len), + SyscallSucceeds()); + } + + int quarter_sz = min + (max - min) / 4; + ASSERT_THAT( + setsockopt(s_, SOL_SOCKET, SO_SNDBUF, &quarter_sz, sizeof(quarter_sz)), + SyscallSucceeds()); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_SNDBUF, &val, &val_len), + SyscallSucceeds()); + + // Linux doubles the value set by SO_SNDBUF/SO_RCVBUF. + // TODO(gvisor.dev/issue/2926): Remove the gvisor special casing when Netstack + // matches linux behavior. + if (!IsRunningOnGvisor()) { + quarter_sz *= 2; + } + + ASSERT_EQ(quarter_sz, val); +} + +TEST_P(RawPacketTest, GetSocketError) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); +} + +TEST_P(RawPacketTest, GetSocketErrorBind) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + { + // Bind to the loopback device. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = GetLoopbackIndex(); + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallSucceeds()); + + // SO_ERROR should return no errors. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } + + { + // Now try binding to an invalid interface. + struct sockaddr_ll bind_addr = {}; + bind_addr.sll_family = AF_PACKET; + bind_addr.sll_protocol = htons(GetParam()); + bind_addr.sll_ifindex = 0xffff; // Just pick a really large number. + + // Binding should fail with EINVAL + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallFailsWithErrno(ENODEV)); + + // SO_ERROR does not return error when the device is invalid. + // On Linux there is just one odd ball condition where this can return + // an error where the device was valid and then removed or disabled + // between the first check for index and the actual registration of + // the packet endpoint. On Netstack this is not possible as the stack + // global mutex is held during registration and check. + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ERROR, &val, &val_len), + SyscallSucceeds()); + ASSERT_EQ(val, 0); + } +} + +#ifndef __fuchsia__ + +TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + // + // gVisor returns no error on SO_DETACH_FILTER even if there is no filter + // attached unlike linux which does return ENOENT in such cases. This is + // because gVisor doesn't support SO_ATTACH_FILTER and just silently returns + // success. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawPacketTest, GetSocketDetachFilter) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index d73d392d5..d6b875dbf 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -754,8 +754,53 @@ TEST(ProcCpuinfo, RequiredFieldsArePresent) { } } -TEST(ProcCpuinfo, DeniesWrite) { - EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); +TEST(ProcCpuinfo, DeniesWriteNonRoot) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER))); + + // Do setuid in a separate thread so that after finishing this test, the + // process can still open files the test harness created before starting this + // test. Otherwise, the files are created by root (UID before the test), but + // cannot be opened by the `uid` set below after the test. After calling + // setuid(non-zero-UID), there is no way to get root privileges back. + ScopedThread([&] { + // Use syscall instead of glibc setuid wrapper because we want this setuid + // call to only apply to this task. POSIX threads, however, require that all + // threads have the same UIDs, so using the setuid wrapper sets all threads' + // real UID. + // Also drops capabilities. + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); + // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in + // kernfs. + if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { + EXPECT_THAT(truncate("/proc/cpuinfo", 123), + SyscallFailsWithErrno(EACCES)); + } + }); +} + +// With root privileges, it is possible to open /proc/cpuinfo with write mode, +// but all write operations will return EIO. +TEST(ProcCpuinfo, DeniesWriteRoot) { + // VFS1 does not behave differently for root/non-root. + SKIP_IF(IsRunningWithVFS1()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_FOWNER))); + + int fd; + EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds()); + if (fd > 0) { + EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO)); + EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO)); + } + // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in + // kernfs. + if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { + if (fd > 0) { + EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO)); + } + EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO)); + } } // Sanity checks that uptime is present. @@ -1967,30 +2012,41 @@ void CheckDuplicatesRecursively(std::string path) { errno = 0; struct dirent* dp = readdir(dir); if (dp == nullptr) { + // Linux will return EINVAL when calling getdents on a /proc/tid/net + // file corresponding to a zombie task. + // See fs/proc/proc_net.c:proc_tgid_net_readdir(). + // + // We just ignore the directory in this case. + if (errno == EINVAL && absl::StartsWith(path, "/proc/") && + absl::EndsWith(path, "/net")) { + break; + } + + // Otherwise, no errors are allowed. ASSERT_EQ(errno, 0) << path; break; // We're done. } - if (strcmp(dp->d_name, ".") == 0 || strcmp(dp->d_name, "..") == 0) { + const std::string name = dp->d_name; + + if (name == "." || name == "..") { continue; } // Ignore a duplicate entry if it isn't the last attempt. if (i == max_attempts - 1) { - ASSERT_EQ(children.find(std::string(dp->d_name)), children.end()) - << absl::StrCat(path, "/", dp->d_name); - } else if (children.find(std::string(dp->d_name)) != children.end()) { + ASSERT_EQ(children.find(name), children.end()) + << absl::StrCat(path, "/", name); + } else if (children.find(name) != children.end()) { std::cerr << "Duplicate entry: " << i << ":" - << absl::StrCat(path, "/", dp->d_name) << std::endl; + << absl::StrCat(path, "/", name) << std::endl; success = false; break; } - children.insert(std::string(dp->d_name)); - - ASSERT_NE(dp->d_type, DT_UNKNOWN); + children.insert(name); if (dp->d_type == DT_DIR) { - child_dirs.push_back(std::string(dp->d_name)); + child_dirs.push_back(name); } } if (success) { diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index aabfa6955..f9392b9e0 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -634,6 +634,11 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // Verify this by setting ICRNL (which rewrites input \r to \n) and verify that // it has no effect on the master. TEST_F(PtyTest, MasterTermiosUnchangable) { + struct kernel_termios master_termios = {}; + EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); + master_termios.c_lflag |= ICRNL; + EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); + char c = '\r'; ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 14a4af980..1d7dbefdb 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -25,16 +25,26 @@ namespace gvisor { namespace testing { -// These tests should be run as root. namespace { +// StealTTY tests whether privileged processes can steal controlling terminals. +// If the stealing process has CAP_SYS_ADMIN in the root user namespace, the +// test ensures that stealing works. If it has non-root CAP_SYS_ADMIN, it +// ensures stealing fails. TEST(JobControlRootTest, StealTTY) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - // Make this a session leader, which also drops the controlling terminal. - // In the gVisor test environment, this test will be run as the session - // leader already (as the sentry init process). + bool true_root = true; if (!IsRunningOnGvisor()) { + // If running in Linux, we may only have CAP_SYS_ADMIN in a non-root user + // namespace (i.e. we are not truly root). We use init_module as a proxy for + // whether we are true root, as it returns EPERM immediately. + ASSERT_THAT(syscall(SYS_init_module, nullptr, 0, nullptr), SyscallFails()); + true_root = errno != EPERM; + + // Make this a session leader, which also drops the controlling terminal. + // In the gVisor test environment, this test will be run as the session + // leader already (as the sentry init process). ASSERT_THAT(setsid(), SyscallSucceeds()); } @@ -53,8 +63,8 @@ TEST(JobControlRootTest, StealTTY) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); - // We should be able to steal it here. - TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1)); + // We should be able to steal it if we are true root. + TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc index 3fe5a600f..63b686c62 100644 --- a/test/syscalls/linux/pwritev2.cc +++ b/test/syscalls/linux/pwritev2.cc @@ -69,7 +69,7 @@ ssize_t pwritev2(unsigned long fd, const struct iovec* iov, } // This test is the base case where we call pwritev (no offset, no flags). -TEST(Writev2Test, TestBaseCall) { +TEST(Writev2Test, BaseCall) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -97,7 +97,7 @@ TEST(Writev2Test, TestBaseCall) { } // This test is where we call pwritev2 with a positive offset and no flags. -TEST(Pwritev2Test, TestValidPositiveOffset) { +TEST(Pwritev2Test, ValidPositiveOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); std::string prefix(kBufSize, '0'); @@ -129,7 +129,7 @@ TEST(Pwritev2Test, TestValidPositiveOffset) { // This test is the base case where we call writev by using -1 as the offset. // The write should use the file offset, so the test increments the file offset // prior to call pwritev2. -TEST(Pwritev2Test, TestNegativeOneOffset) { +TEST(Pwritev2Test, NegativeOneOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const std::string prefix = "00"; @@ -164,7 +164,7 @@ TEST(Pwritev2Test, TestNegativeOneOffset) { // pwritev2 requires if the RWF_HIPRI flag is passed, the fd must be opened with // O_DIRECT. This test implements a correct call with the RWF_HIPRI flag. -TEST(Pwritev2Test, TestCallWithRWF_HIPRI) { +TEST(Pwritev2Test, CallWithRWF_HIPRI) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -189,47 +189,8 @@ TEST(Pwritev2Test, TestCallWithRWF_HIPRI) { EXPECT_EQ(buf, content); } -// This test checks that pwritev2 can be called with valid flags -TEST(Pwritev2Test, TestCallWithValidFlags) { - SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); - - const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( - GetAbsoluteTestTmpdir(), "", TempPath::kDefaultFileMode)); - const FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); - - std::vector<char> content(kBufSize, '0'); - struct iovec iov; - iov.iov_base = content.data(); - iov.iov_len = content.size(); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/RWF_DSYNC), - SyscallSucceedsWithValue(kBufSize)); - - std::vector<char> buf(content.size()); - EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); - - SetContent(content); - - EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1, - /*offset=*/0, /*flags=*/0x4), - SyscallSucceedsWithValue(kBufSize)); - - ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), - SyscallSucceedsWithValue(content.size())); - - EXPECT_THAT(pread(fd.get(), buf.data(), buf.size(), /*offset=*/0), - SyscallSucceedsWithValue(buf.size())); - - EXPECT_EQ(buf, content); -} - // This test calls pwritev2 with a bad file descriptor. -TEST(Writev2Test, TestBadFile) { +TEST(Writev2Test, BadFile) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); ASSERT_THAT(pwritev2(/*fd=*/-1, /*iov=*/nullptr, /*iovcnt=*/0, /*offset=*/0, /*flags=*/0), @@ -237,7 +198,7 @@ TEST(Writev2Test, TestBadFile) { } // This test calls pwrite2 with an invalid offset. -TEST(Pwritev2Test, TestInvalidOffset) { +TEST(Pwritev2Test, InvalidOffset) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -255,7 +216,7 @@ TEST(Pwritev2Test, TestInvalidOffset) { SyscallFailsWithErrno(EINVAL)); } -TEST(Pwritev2Test, TestUnseekableFileValid) { +TEST(Pwritev2Test, UnseekableFileValid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; @@ -285,7 +246,7 @@ TEST(Pwritev2Test, TestUnseekableFileValid) { // Calling pwritev2 with a non-negative offset calls pwritev. Calling pwritev // with an unseekable file is not allowed. A pipe is used for an unseekable // file. -TEST(Pwritev2Test, TestUnseekableFileInValid) { +TEST(Pwritev2Test, UnseekableFileInvalid) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); int pipe_fds[2]; @@ -304,7 +265,7 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) { EXPECT_THAT(close(pipe_fds[1]), SyscallSucceeds()); } -TEST(Pwritev2Test, TestReadOnlyFile) { +TEST(Pwritev2Test, ReadOnlyFile) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( @@ -323,7 +284,7 @@ TEST(Pwritev2Test, TestReadOnlyFile) { } // This test calls pwritev2 with an invalid flag. -TEST(Pwritev2Test, TestInvalidFlag) { +TEST(Pwritev2Test, InvalidFlag) { SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket.cc index 0116c3e94..8d6e5c913 100644 --- a/test/syscalls/linux/raw_socket_ipv4.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -13,8 +13,12 @@ // limitations under the License. #include <linux/capability.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <netinet/in.h> #include <netinet/ip.h> +#include <netinet/ip6.h> #include <netinet/ip_icmp.h> #include <poll.h> #include <sys/socket.h> @@ -39,7 +43,7 @@ namespace testing { namespace { // Fixture for tests parameterized by protocol. -class RawSocketTest : public ::testing::TestWithParam<int> { +class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { protected: // Creates a socket to be used in tests. void SetUp() override; @@ -50,36 +54,58 @@ class RawSocketTest : public ::testing::TestWithParam<int> { // Sends buf via s_. void SendBuf(const char* buf, int buf_len); - // Sends buf to the provided address via the provided socket. - void SendBufTo(int sock, const struct sockaddr_in& addr, const char* buf, - int buf_len); - // Reads from s_ into recv_buf. void ReceiveBuf(char* recv_buf, size_t recv_buf_len); - int Protocol() { return GetParam(); } + void ReceiveBufFrom(int sock, char* recv_buf, size_t recv_buf_len); + + int Protocol() { return std::get<0>(GetParam()); } + + int Family() { return std::get<1>(GetParam()); } + + socklen_t AddrLen() { + if (Family() == AF_INET) { + return sizeof(sockaddr_in); + } + return sizeof(sockaddr_in6); + } + + int HdrLen() { + if (Family() == AF_INET) { + return sizeof(struct iphdr); + } + // IPv6 raw sockets don't include the header. + return 0; + } // The socket used for both reading and writing. int s_; // The loopback address. - struct sockaddr_in addr_; + struct sockaddr_storage addr_; }; void RawSocketTest::SetUp() { if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) { - ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()), + ASSERT_THAT(socket(Family(), SOCK_RAW, Protocol()), SyscallFailsWithErrno(EPERM)); GTEST_SKIP(); } - ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s_ = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); addr_ = {}; // We don't set ports because raw sockets don't have a notion of ports. - addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); - addr_.sin_family = AF_INET; + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr_); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr_); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_loopback; + } } void RawSocketTest::TearDown() { @@ -96,7 +122,7 @@ TEST_P(RawSocketTest, MultipleCreation) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); ASSERT_THAT(close(s2), SyscallSucceeds()); } @@ -114,7 +140,7 @@ TEST_P(RawSocketTest, ShutdownWriteNoop) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); @@ -129,7 +155,7 @@ TEST_P(RawSocketTest, ShutdownReadNoop) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); @@ -137,9 +163,8 @@ TEST_P(RawSocketTest, ShutdownReadNoop) { constexpr char kBuf[] = "gdg"; ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); - constexpr size_t kReadSize = sizeof(kBuf) + sizeof(struct iphdr); - char c[kReadSize]; - ASSERT_THAT(read(s_, &c, sizeof(c)), SyscallSucceedsWithValue(kReadSize)); + std::vector<char> c(sizeof(kBuf) + HdrLen()); + ASSERT_THAT(read(s_, c.data(), c.size()), SyscallSucceedsWithValue(c.size())); } // Test that listen() fails. @@ -173,7 +198,7 @@ TEST_P(RawSocketTest, GetPeerName) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); struct sockaddr saddr; socklen_t addrlen = sizeof(saddr); @@ -223,7 +248,7 @@ TEST_P(RawSocketTest, ConnectToLoopback) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); } @@ -237,12 +262,33 @@ TEST_P(RawSocketTest, SendWithoutConnectFails) { SyscallFailsWithErrno(EDESTADDRREQ)); } +// Wildcard Bind. +TEST_P(RawSocketTest, BindToWildcard) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + struct sockaddr_storage addr; + addr = {}; + + // We don't set ports because raw sockets don't have a notion of ports. + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = htonl(INADDR_ANY); + } else { + struct sockaddr_in6* sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = in6addr_any; + } + + ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), + SyscallSucceeds()); +} + // Bind to localhost. TEST_P(RawSocketTest, BindToLocalhost) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); } @@ -250,12 +296,18 @@ TEST_P(RawSocketTest, BindToLocalhost) { TEST_P(RawSocketTest, BindToInvalid) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); - struct sockaddr_in bind_addr = {}; - bind_addr.sin_family = AF_INET; - bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + struct sockaddr_storage bind_addr = addr_; + if (Family() == AF_INET) { + struct sockaddr_in* sin = reinterpret_cast<struct sockaddr_in*>(&bind_addr); + sin->sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + } else { + struct sockaddr_in6* sin6 = + reinterpret_cast<struct sockaddr_in6*>(&bind_addr); + memset(&sin6->sin6_addr.s6_addr, 0, sizeof(sin6->sin6_addr.s6_addr)); + sin6->sin6_addr.s6_addr[0] = 1; // 1: - An address that we can't bind to. + } ASSERT_THAT(bind(s_, reinterpret_cast<struct sockaddr*>(&bind_addr), - sizeof(bind_addr)), - SyscallFailsWithErrno(EADDRNOTAVAIL)); + AddrLen()), SyscallFailsWithErrno(EADDRNOTAVAIL)); } // Send and receive an packet. @@ -267,9 +319,9 @@ TEST_P(RawSocketTest, SendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // We should be able to create multiple raw sockets for the same protocol and @@ -278,22 +330,23 @@ TEST_P(RawSocketTest, MultipleSocketReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); int s2; - ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds()); + ASSERT_THAT(s2 = socket(Family(), SOCK_RAW, Protocol()), SyscallSucceeds()); // Arbitrary. constexpr char kBuf[] = "TB10"; ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive it on socket 1. - char recv_buf1[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1, sizeof(recv_buf1))); + std::vector<char> recv_buf1(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf1.data(), recv_buf1.size())); // Receive it on socket 2. - char recv_buf2[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s2, recv_buf2, sizeof(recv_buf2))); + std::vector<char> recv_buf2(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s2, recv_buf2.data(), + recv_buf2.size())); - EXPECT_EQ(memcmp(recv_buf1 + sizeof(struct iphdr), - recv_buf2 + sizeof(struct iphdr), sizeof(kBuf)), + EXPECT_EQ(memcmp(recv_buf1.data() + HdrLen(), + recv_buf2.data() + HdrLen(), sizeof(kBuf)), 0); ASSERT_THAT(close(s2), SyscallSucceeds()); @@ -304,7 +357,7 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -313,9 +366,9 @@ TEST_P(RawSocketTest, SendAndReceiveViaConnect) { SyscallSucceedsWithValue(sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Bind to localhost, then send and receive packets. @@ -323,7 +376,7 @@ TEST_P(RawSocketTest, BindSendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -331,9 +384,9 @@ TEST_P(RawSocketTest, BindSendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Bind and connect to localhost and send/receive packets. @@ -341,10 +394,10 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); // Arbitrary. @@ -352,9 +405,9 @@ TEST_P(RawSocketTest, BindConnectSendAndReceive) { ASSERT_NO_FATAL_FAILURE(SendBuf(kBuf, sizeof(kBuf))); // Receive the packet and make sure it's identical. - char recv_buf[sizeof(kBuf) + sizeof(struct iphdr)]; - ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf, sizeof(recv_buf))); - EXPECT_EQ(memcmp(recv_buf + sizeof(struct iphdr), kBuf, sizeof(kBuf)), 0); + std::vector<char> recv_buf(sizeof(kBuf) + HdrLen()); + ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), kBuf, sizeof(kBuf)), 0); } // Check that setting SO_RCVBUF below min is clamped to the minimum @@ -580,20 +633,16 @@ TEST_P(RawSocketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } -void RawSocketTest::SendBuf(const char* buf, int buf_len) { - ASSERT_NO_FATAL_FAILURE(SendBufTo(s_, addr_, buf, buf_len)); -} - // Test that receive buffer limits are not enforced when the recv buffer is // empty. TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); int min = 0; @@ -616,10 +665,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); EXPECT_EQ( - memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } @@ -631,10 +680,10 @@ TEST_P(RawSocketTest, RecvBufLimitsEmptyRecvBuffer) { RandomizeBuffer(buf.data(), buf.size()); ASSERT_NO_FATAL_FAILURE(SendBuf(buf.data(), buf.size())); // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); EXPECT_EQ( - memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), buf.size()), + memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } } @@ -652,10 +701,10 @@ TEST_P(RawSocketTest, RecvBufLimits) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); ASSERT_THAT( - bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + bind(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); ASSERT_THAT( - connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + connect(s_, reinterpret_cast<struct sockaddr*>(&addr_), AddrLen()), SyscallSucceeds()); int min = 0; @@ -716,16 +765,16 @@ TEST_P(RawSocketTest, RecvBufLimits) { // Verify that the expected number of packets are available to be read. for (int i = 0; i < sent - 1; i++) { // Receive the packet and make sure it's identical. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); ASSERT_NO_FATAL_FAILURE(ReceiveBuf(recv_buf.data(), recv_buf.size())); - EXPECT_EQ(memcmp(recv_buf.data() + sizeof(struct iphdr), buf.data(), + EXPECT_EQ(memcmp(recv_buf.data() + HdrLen(), buf.data(), buf.size()), 0); } // Assert that the last packet is dropped because the receive buffer should // be full after the first four packets. - std::vector<char> recv_buf(buf.size() + sizeof(struct iphdr)); + std::vector<char> recv_buf(buf.size() + HdrLen()); struct iovec iov = {}; iov.iov_base = static_cast<void*>(const_cast<char*>(recv_buf.data())); iov.iov_len = buf.size(); @@ -740,30 +789,79 @@ TEST_P(RawSocketTest, RecvBufLimits) { } } -void RawSocketTest::SendBufTo(int sock, const struct sockaddr_in& addr, - const char* buf, int buf_len) { +void RawSocketTest::SendBuf(const char* buf, int buf_len) { // It's safe to use const_cast here because sendmsg won't modify the iovec or // address. struct iovec iov = {}; iov.iov_base = static_cast<void*>(const_cast<char*>(buf)); iov.iov_len = static_cast<size_t>(buf_len); struct msghdr msg = {}; - msg.msg_name = static_cast<void*>(const_cast<struct sockaddr_in*>(&addr)); - msg.msg_namelen = sizeof(addr); + msg.msg_name = static_cast<void*>(&addr_); + msg.msg_namelen = AddrLen(); msg.msg_iov = &iov; msg.msg_iovlen = 1; msg.msg_control = NULL; msg.msg_controllen = 0; msg.msg_flags = 0; - ASSERT_THAT(sendmsg(sock, &msg, 0), SyscallSucceedsWithValue(buf_len)); + ASSERT_THAT(sendmsg(s_, &msg, 0), SyscallSucceedsWithValue(buf_len)); } void RawSocketTest::ReceiveBuf(char* recv_buf, size_t recv_buf_len) { - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, recv_buf_len)); + ASSERT_NO_FATAL_FAILURE(ReceiveBufFrom(s_, recv_buf, recv_buf_len)); +} + +void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, + size_t recv_buf_len) { + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); +} + +#ifndef __fuchsia__ + +TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + if (IsRunningOnGvisor()) { + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); + return; + } + + constexpr int val = 0; + ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(RawSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + +// AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. +TEST(RawSocketTest, IPv6ProtoRaw) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + // Verify that writing yields EINVAL. + char buf[] = "This is such a weird little edge case"; + struct sockaddr_in6 sin6 = {}; + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = in6addr_loopback; + ASSERT_THAT(sendto(sock, buf, sizeof(buf), 0 /* flags */, + reinterpret_cast<struct sockaddr*>(&sin6), sizeof(sin6)), + SyscallFailsWithErrno(EINVAL)); } -INSTANTIATE_TEST_SUITE_P(AllInetTests, RawSocketTest, - ::testing::Values(IPPROTO_TCP, IPPROTO_UDP)); +INSTANTIATE_TEST_SUITE_P( + AllInetTests, RawSocketTest, + ::testing::Combine(::testing::Values(IPPROTO_TCP, IPPROTO_UDP), + ::testing::Values(AF_INET, AF_INET6))); } // namespace diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc index 0a27506aa..97f0467aa 100644 --- a/test/syscalls/linux/raw_socket_hdrincl.cc +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -167,7 +167,7 @@ TEST_F(RawHDRINCL, NotReadable) { // nothing to be read. char buf[117]; ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EINVAL)); + SyscallFailsWithErrno(EAGAIN)); } // Test that we can connect() to a valid IP (loopback). @@ -178,6 +178,9 @@ TEST_F(RawHDRINCL, ConnectToLoopback) { } TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { + // FIXME(github.dev/issue/3159): Test currently flaky. + SKIP_IF(true); + struct iphdr hdr = LoopbackHeader(); ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), SyscallSucceedsWithValue(sizeof(hdr))); @@ -273,14 +276,17 @@ TEST_F(RawHDRINCL, SendAndReceive) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr iphdr = {}; - memcpy(&iphdr, recv_buf, sizeof(iphdr)); - EXPECT_EQ(iphdr.id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); + EXPECT_NE(iphdr->id, 0); } -// Send and receive a packet with nonzero IP ID. -TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { +// Send and receive a packet where the sendto address is not the same as the +// provided destination. +TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { + // FIXME(github.dev/issue/3160): Test currently flaky. + SKIP_IF(true); + int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( @@ -292,19 +298,24 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { FileDescriptor udp_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); - // Construct a packet with an IP header, UDP header, and payload. Make the - // payload large enough to force an IP ID to be assigned. - constexpr char kPayload[128] = {}; + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + // Overwrite the IP destination address with an IP we can't get to. + struct iphdr iphdr = {}; + memcpy(&iphdr, packet, sizeof(iphdr)); + iphdr.daddr = 42; + memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload. + // Receive the payload, since sendto should replace the bad destination with + // localhost. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); @@ -318,47 +329,58 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should not be 0, as the packet was more than 68 bytes. - struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); - EXPECT_NE(iphdr->id, 0); + // The packet ID should not be 0, as the packet has DF=0. + struct iphdr recv_iphdr = {}; + memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); + EXPECT_NE(recv_iphdr.id, 0); + // The destination address should be localhost, not the bad IP we set + // initially. + EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); } -// Send and receive a packet where the sendto address is not the same as the -// provided destination. -TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { +// Send and receive a packet w/ the IP_HDRINCL option set. +TEST_F(RawHDRINCL, SendAndReceiveIPHdrIncl) { int port = 40000; if (!IsRunningOnGvisor()) { port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); } - // IPPROTO_RAW sockets are write-only. We'll have to open another socket to - // read what we write. - FileDescriptor udp_sock = + FileDescriptor recv_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + FileDescriptor send_sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + // Enable IP_HDRINCL option so that we can build and send w/ an IP + // header. + constexpr int kSockOptOn = 1; + ASSERT_THAT(setsockopt(send_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // This is not strictly required but we do it to make sure that setting + // IP_HDRINCL on a non IPPROTO_RAW socket does not prevent it from receiving + // packets. + ASSERT_THAT(setsockopt(recv_sock.get(), SOL_IP, IP_HDRINCL, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + // Construct a packet with an IP header, UDP header, and payload. constexpr char kPayload[] = "toto"; char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; ASSERT_TRUE( FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); - // Overwrite the IP destination address with an IP we can't get to. - struct iphdr iphdr = {}; - memcpy(&iphdr, packet, sizeof(iphdr)); - iphdr.daddr = 42; - memcpy(packet, &iphdr, sizeof(iphdr)); socklen_t addrlen = sizeof(addr_); - ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + ASSERT_NO_FATAL_FAILURE(sendto(send_sock.get(), &packet, sizeof(packet), 0, reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); - // Receive the payload, since sendto should replace the bad destination with - // localhost. + // Receive the payload. char recv_buf[sizeof(packet)]; struct sockaddr_in src; socklen_t src_size = sizeof(src); - ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + ASSERT_THAT(recvfrom(recv_sock.get(), recv_buf, sizeof(recv_buf), 0, reinterpret_cast<struct sockaddr*>(&src), &src_size), SyscallSucceedsWithValue(sizeof(packet))); EXPECT_EQ( @@ -368,13 +390,20 @@ TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { // The network stack should have set the source address. EXPECT_EQ(src.sin_family, AF_INET); EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); - // The packet ID should be 0, as the packet is less than 68 bytes. - struct iphdr recv_iphdr = {}; - memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); - EXPECT_EQ(recv_iphdr.id, 0); - // The destination address should be localhost, not the bad IP we set - // initially. - EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); + struct iphdr iphdr = {}; + memcpy(&iphdr, recv_buf, sizeof(iphdr)); + EXPECT_NE(iphdr.id, 0); + + // Also verify that the packet we just sent was not delivered to the + // IPPROTO_RAW socket. + { + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(socket_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallFailsWithErrno(EAGAIN)); + } } } // namespace diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc index 4b2cb0606..d3cc71dbf 100644 --- a/test/syscalls/linux/socket_bind_to_device_sequence.cc +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -180,7 +180,7 @@ class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { private: SocketKind socket_factory_; // devices maps from the device id in the test case to the name of the device. - std::unordered_map<int, string> devices_; + absl::node_hash_map<int, string> devices_; // These are the tunnels that were created for the test and will be destroyed // by the destructor. vector<std::unique_ptr<Tunnel>> tunnels_; @@ -363,10 +363,6 @@ TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { } TEST_P(BindToDeviceSequenceTest, BindWithReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE( BindSocket(/* reusePort */ false, /* reuse_addr */ true)); ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ false, @@ -406,10 +402,6 @@ TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindReuseAddrReusePortThenReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); @@ -436,10 +428,6 @@ TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReusePort) { } TEST_P(BindToDeviceSequenceTest, BindDoubleReuseAddrReusePortThenReuseAddr) { - // FIXME(b/158253835): Support SO_REUSEADDR on TCP sockets. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - ASSERT_NO_FATAL_FAILURE(BindSocket( /* reuse_port */ true, /* reuse_addr */ true, /* bind_to_device */ 0)); ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse_port */ true, diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index a914d9a12..18b9e4b70 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -17,6 +17,7 @@ #include <netinet/tcp.h> #include <poll.h> #include <string.h> +#include <sys/socket.h> #include <atomic> #include <iostream> @@ -686,24 +687,9 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { // be restarted causing the final bind/connect to fail. DisableSave ds; - // TODO(gvisor.dev/issue/1030): Portmanager does not track all 5 tuple - // reservations which causes the bind() to succeed on gVisor but connect - // correctly fails. - if (IsRunningOnGvisor()) { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } else { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); - } + ASSERT_THAT(bind(conn_fd2.get(), + reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + SyscallFailsWithErrno(EADDRINUSE)); // Sleep for a little over the linger timeout to reduce flakiness in // save/restore tests. @@ -1737,6 +1723,171 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, DualStackV6AnyReservesEverything) { ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), test_addr_v4.addr_len), SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrDoesNotReserveV4Any) { + auto const& param = GetParam(); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyReuseAddrListenReservesV4Any) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_dual.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v4 any on the same port with a v4 socket succeeds. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(setsockopt(fd_v4_any.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketMultiProtocolInetLoopbackTest, + DualStackV6AnyWithListenReservesEverything) { + auto const& param = GetParam(); + + // Only TCP sockets are supported. + SKIP_IF((param.type & SOCK_STREAM) == 0); + + // Bind the v6 any on a dual stack socket. + TestAddress const& test_addr_dual = V6Any(); + sockaddr_storage addr_dual = test_addr_dual.addr; + const FileDescriptor fd_dual = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_dual.family(), param.type, 0)); + ASSERT_THAT(bind(fd_dual.get(), reinterpret_cast<sockaddr*>(&addr_dual), + test_addr_dual.addr_len), + SyscallSucceeds()); + + ASSERT_THAT(listen(fd_dual.get(), 5), SyscallSucceeds()); + + // Get the port that we bound. + socklen_t addrlen = test_addr_dual.addr_len; + ASSERT_THAT(getsockname(fd_dual.get(), + reinterpret_cast<sockaddr*>(&addr_dual), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(test_addr_dual.family(), addr_dual)); + + // Verify that binding the v6 loopback with the same port fails. + TestAddress const& test_addr_v6 = V6Loopback(); + sockaddr_storage addr_v6 = test_addr_v6.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v6.family(), &addr_v6, port)); + const FileDescriptor fd_v6 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v6.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v6.get(), reinterpret_cast<sockaddr*>(&addr_v6), + test_addr_v6.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v6 socket + // fails. + TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); + sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; + ASSERT_NO_ERRNO( + SetAddrPort(test_addr_v4_mapped.family(), &addr_v4_mapped, port)); + const FileDescriptor fd_v4_mapped = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_mapped.family(), param.type, 0)); + ASSERT_THAT( + bind(fd_v4_mapped.get(), reinterpret_cast<sockaddr*>(&addr_v4_mapped), + test_addr_v4_mapped.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 loopback on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4 = V4Loopback(); + sockaddr_storage addr_v4 = test_addr_v4.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4.family(), &addr_v4, port)); + const FileDescriptor fd_v4 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr_v4.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4.get(), reinterpret_cast<sockaddr*>(&addr_v4), + test_addr_v4.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); + + // Verify that binding the v4 any on the same port with a v4 socket + // fails. + TestAddress const& test_addr_v4_any = V4Any(); + sockaddr_storage addr_v4_any = test_addr_v4_any.addr; + ASSERT_NO_ERRNO(SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, port)); + const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( + Socket(test_addr_v4_any.family(), param.type, 0)); + ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), + test_addr_v4_any.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { @@ -1798,10 +1949,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6OnlyV6AnyReservesV6) { TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v6 loopback on a dual stack socket. TestAddress const& test_addr = V6Loopback(); @@ -1864,17 +2011,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReserved) { test_addr_v6.addr_len), SyscallFailsWithErrno(EADDRINUSE)); - // Verify that binding the v4 any with the same port fails. - TestAddress const& test_addr_v4_any = V4Any(); - sockaddr_storage addr_v4_any = test_addr_v4_any.addr; - ASSERT_NO_ERRNO( - SetAddrPort(test_addr_v4_any.family(), &addr_v4_any, ephemeral_port)); - const FileDescriptor fd_v4_any = ASSERT_NO_ERRNO_AND_VALUE( - Socket(test_addr_v4_any.family(), param.type, 0)); - ASSERT_THAT(bind(fd_v4_any.get(), reinterpret_cast<sockaddr*>(&addr_v4_any), - test_addr_v4_any.addr_len), - SyscallFailsWithErrno(EADDRINUSE)); - // Verify that we can still bind the v4 loopback on the same port. TestAddress const& test_addr_v4_mapped = V4MappedLoopback(); sockaddr_storage addr_v4_mapped = test_addr_v4_mapped.addr; @@ -1963,10 +2099,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V6EphemeralPortReservedReuseAddr) { TEST_P(SocketMultiProtocolInetLoopbackTest, V4MappedEphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v4 loopback on a dual stack socket. TestAddress const& test_addr = V4MappedLoopback(); @@ -2152,10 +2284,6 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReserved) { auto const& param = GetParam(); - // FIXME(b/76031995): Support disabling SO_REUSEADDR for TCP sockets and make - // it disabled by default. - SKIP_IF(IsRunningOnGvisor() && param.type == SOCK_STREAM); - for (int i = 0; true; i++) { // Bind the v4 loopback on a v4 socket. TestAddress const& test_addr = V4Loopback(); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 2324c7f6a..791e2bd51 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -82,8 +82,11 @@ using SocketInetLoopbackTest = ::testing::TestWithParam<TestParam>; // This test verifies that connect returns EADDRNOTAVAIL if all local ephemeral // ports are already in use for a given destination ip/port. +// // We disable S/R because this test creates a large number of sockets. -TEST_P(SocketInetLoopbackTest, TestTCPPortExhaustion_NoRandomSave) { +// +// FIXME(b/162475855): This test is failing reliably. +TEST_P(SocketInetLoopbackTest, DISABLED_TestTCPPortExhaustion_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 796598eee..1c7b0cf90 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -425,10 +425,6 @@ TEST_P(IPUnboundSocketTest, InsufficientBufferTOS) { TEST_P(IPUnboundSocketTest, ReuseAddrDefault) { auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - // FIXME(b/76031995): gVisor TCP sockets default to SO_REUSEADDR enabled. - SKIP_IF(IsRunningOnGvisor() && - (GetParam().type & SOCK_STREAM) == SOCK_STREAM); - int get = -1; socklen_t get_sz = sizeof(get); ASSERT_THAT( diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 15d4b85a7..5f8d7f981 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <linux/ethtool.h> #include <linux/netlink.h> #include <linux/rtnetlink.h> #include <linux/sockios.h> @@ -49,6 +50,7 @@ TEST(NetdeviceTest, Loopback) { // Check that the loopback is zero hardware address. ASSERT_THAT(ioctl(sock.get(), SIOCGIFHWADDR, &ifr), SyscallSucceeds()); + EXPECT_EQ(ifr.ifr_hwaddr.sa_family, ARPHRD_LOOPBACK); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[0], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[1], 0); EXPECT_EQ(ifr.ifr_hwaddr.sa_data[2], 0); @@ -178,6 +180,27 @@ TEST(NetdeviceTest, InterfaceMTU) { EXPECT_GT(ifr.ifr_mtu, 0); } +TEST(NetdeviceTest, EthtoolGetTSInfo) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + + struct ethtool_ts_info tsi = {}; + tsi.cmd = ETHTOOL_GET_TS_INFO; // Get NIC's Timestamping capabilities. + + // Prepare the request. + struct ifreq ifr = {}; + snprintf(ifr.ifr_name, IFNAMSIZ, "lo"); + ifr.ifr_data = (void*)&tsi; + + // Check that SIOCGIFMTU returns a nonzero MTU. + if (IsRunningOnGvisor()) { + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), + SyscallFailsWithErrno(EOPNOTSUPP)); + return; + } + ASSERT_THAT(ioctl(sock.get(), SIOCETHTOOL, &ifr), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index fbe61c5a0..e6647a1c3 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -595,7 +595,7 @@ TEST(NetlinkRouteTest, GetRouteRequest) { ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); uint32_t port = ASSERT_NO_ERRNO_AND_VALUE(NetlinkPortID(fd.get())); - struct __attribute__((__packed__)) request { + struct request { struct nlmsghdr hdr; struct rtmsg rtm; struct nlattr nla; diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc index 92eec0449..4afed6d08 100644 --- a/test/syscalls/linux/sticky.cc +++ b/test/syscalls/linux/sticky.cc @@ -40,11 +40,17 @@ namespace { TEST(StickyTest, StickyBitPermDenied) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY)); - ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -62,7 +68,13 @@ TEST(StickyTest, StickyBitPermDenied) { syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1), SyscallSucceeds()); - EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR), + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "file", 0), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), + SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallFailsWithErrno(EPERM)); }); } @@ -70,10 +82,17 @@ TEST(StickyTest, StickyBitPermDenied) { TEST(StickyTest, StickyBitSameUID) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - std::string path = JoinPath(dir.path(), "NewDir"); - ASSERT_THAT(mkdir(path.c_str(), 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -89,18 +108,29 @@ TEST(StickyTest, StickyBitSameUID) { SyscallSucceeds()); // We still have the same EUID. - EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds()); + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds()); }); } TEST(StickyTest, StickyBitCapFOWNER) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); - auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - EXPECT_THAT(chmod(dir.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); - const FileDescriptor dirfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_DIRECTORY)); - ASSERT_THAT(mkdirat(dirfd.get(), "NewDir", 0755), SyscallSucceeds()); + const TempPath parent = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(chmod(parent.path().c_str(), 0777 | S_ISVTX), SyscallSucceeds()); + + // After changing credentials below, we need to use an open fd to make + // modifications in the parent dir, because there is no guarantee that we will + // still have the ability to open it. + const FileDescriptor parent_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(parent.path(), O_DIRECTORY)); + ASSERT_THAT(openat(parent_fd.get(), "file", O_CREAT), SyscallSucceeds()); + ASSERT_THAT(mkdirat(parent_fd.get(), "dir", 0777), SyscallSucceeds()); + ASSERT_THAT(symlinkat("xyz", parent_fd.get(), "link"), SyscallSucceeds()); // Drop privileges and change IDs only in child thread, or else this parent // thread won't be able to open some log files after the test ends. @@ -117,8 +147,12 @@ TEST(StickyTest, StickyBitCapFOWNER) { SyscallSucceeds()); EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true)); - EXPECT_THAT(unlinkat(dirfd.get(), "NewDir", AT_REMOVEDIR), + EXPECT_THAT(renameat(parent_fd.get(), "file", parent_fd.get(), "file2"), + SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "file2", 0), SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "dir", AT_REMOVEDIR), SyscallSucceeds()); + EXPECT_THAT(unlinkat(parent_fd.get(), "link", 0), SyscallSucceeds()); }); } } // namespace diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index a4d2953e1..0cea7d11f 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,6 +13,9 @@ // limitations under the License. #include <fcntl.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -1559,6 +1562,63 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { } } +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + // Program generated using sudo tcpdump -i lo tcp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000006}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000006}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/timerfd.cc b/test/syscalls/linux/timerfd.cc index 86ed87b7c..c4f8fdd7a 100644 --- a/test/syscalls/linux/timerfd.cc +++ b/test/syscalls/linux/timerfd.cc @@ -204,16 +204,33 @@ TEST_P(TimerfdTest, SetAbsoluteTime) { EXPECT_EQ(1, val); } -TEST_P(TimerfdTest, IllegalReadWrite) { +TEST_P(TimerfdTest, IllegalSeek) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + if (!IsRunningWithVFS1()) { + EXPECT_THAT(lseek(tfd.get(), 0, SEEK_SET), SyscallFailsWithErrno(ESPIPE)); + } +} + +TEST_P(TimerfdTest, IllegalPread) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + int val; + EXPECT_THAT(pread(tfd.get(), &val, sizeof(val), 0), + SyscallFailsWithErrno(ESPIPE)); +} + +TEST_P(TimerfdTest, IllegalPwrite) { + auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), 0)); + EXPECT_THAT(pwrite(tfd.get(), "x", 1, 0), SyscallFailsWithErrno(ESPIPE)); + if (!IsRunningWithVFS1()) { + } +} + +TEST_P(TimerfdTest, IllegalWrite) { auto const tfd = ASSERT_NO_ERRNO_AND_VALUE(TimerfdCreate(GetParam(), TFD_NONBLOCK)); uint64_t val = 0; - EXPECT_THAT(PreadFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); - EXPECT_THAT(WriteFd(tfd.get(), &val, sizeof(val)), + EXPECT_THAT(write(tfd.get(), &val, sizeof(val)), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(PwriteFd(tfd.get(), &val, sizeof(val), 0), - SyscallFailsWithErrno(ESPIPE)); } std::string PrintClockId(::testing::TestParamInfo<int> info) { diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc index fcdba7279..54a0594f7 100644 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ b/test/syscalls/linux/udp_socket_errqueue_test_case.cc @@ -47,7 +47,7 @@ TEST_P(UdpSocketTest, ErrorQueue) { msg.msg_controllen = sizeof(cmsgbuf); // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(s_, &msg, MSG_ERRQUEUE), + EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), SyscallFailsWithErrno(EAGAIN)); } diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc index cc1db3de8..60c48ed6e 100644 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ b/test/syscalls/linux/udp_socket_test_cases.cc @@ -16,12 +16,16 @@ #include <arpa/inet.h> #include <fcntl.h> +#ifndef __fuchsia__ +#include <linux/filter.h> +#endif // __fuchsia__ #include <netinet/in.h> #include <poll.h> #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/types.h> +#include "absl/strings/str_format.h" #ifndef SIOCGSTAMP #include <linux/sockios.h> #endif @@ -30,8 +34,11 @@ #include "absl/base/macros.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -54,517 +61,470 @@ uint16_t* Port(struct sockaddr_storage* addr) { return nullptr; } +// Sets addr port to "port". +void SetPort(struct sockaddr_storage* addr, uint16_t port) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + sin->sin_port = port; + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + sin6->sin6_port = port; + break; + } + } +} + void UdpSocketTest::SetUp() { - int type; + addrlen_ = GetAddrLength(); + + bind_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); + bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + + sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); +} + +int UdpSocketTest::GetFamily() { if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; - auto sin = reinterpret_cast<struct sockaddr_in*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - type = AF_INET6; - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&anyaddr_storage_); - addrlen_ = sizeof(*sin6); - if (GetParam() == AddressFamily::kIpv6) { - sin6->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - sin6->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } + return AF_INET; } - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); + return AF_INET6; +} - ASSERT_THAT(t_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); +PosixError UdpSocketTest::BindLoopback() { + bind_addr_storage_ = InetLoopbackAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} - memset(&anyaddr_storage_, 0, sizeof(anyaddr_storage_)); - anyaddr_ = reinterpret_cast<struct sockaddr*>(&anyaddr_storage_); - anyaddr_->sa_family = type; +PosixError UdpSocketTest::BindAny() { + bind_addr_storage_ = InetAnyAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} - if (gvisor::testing::IsRunningOnGvisor()) { - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ports_[i] = TestPort + i; - } - } else { - // When not under gvisor, use utility function to pick port. Assert that - // all ports are different. - std::string error; - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - // Find an unused port, we specify port 0 to allow the kernel to provide - // the port. - bool unique = true; - do { - ports_[i] = ASSERT_NO_ERRNO_AND_VALUE(PortAvailable( - 0, AddressFamily::kDualStack, SocketType::kUdp, false)); - ASSERT_GT(ports_[i], 0); - for (size_t j = 0; j < i; ++j) { - if (ports_[j] == ports_[i]) { - unique = false; - break; - } - } - } while (!unique); - } +PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { + socklen_t len = sizeof(bind_addr_storage_); + + // Bind, then check that we get the right address. + RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); + + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); + + if (addrlen_ != len) { + return PosixError( + EINVAL, + absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); } + return PosixError(0); +} - // Initialize the sockaddrs. - for (size_t i = 0; i < ABSL_ARRAYSIZE(addr_); ++i) { - memset(&addr_storage_[i], 0, sizeof(addr_storage_[i])); - - addr_[i] = reinterpret_cast<struct sockaddr*>(&addr_storage_[i]); - addr_[i]->sa_family = type; - - switch (type) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr_[i]); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(ports_[i]); - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr_[i]); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(ports_[i]); - break; - } - } +socklen_t UdpSocketTest::GetAddrLength() { + struct sockaddr_storage addr; + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + return sizeof(*sin); } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + return sizeof(*sin6); } -TEST_P(UdpSocketTest, Creation) { - int type = AF_INET6; - if (GetParam() == AddressFamily::kIpv4) { - type = AF_INET; +sockaddr_storage UdpSocketTest::InetAnyAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + sin->sin_port = htons(0); + return addr; } - int s_; + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = IN6ADDR_ANY_INIT; + sin6->sin6_port = htons(0); + return addr; +} + +sockaddr_storage UdpSocketTest::InetLoopbackAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(0); + return addr; + } + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(0); + return addr; +} - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, IPPROTO_UDP), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); +void UdpSocketTest::Disconnect(int sockfd) { + sockaddr_storage addr_storage = InetAnyAddr(); + sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + socklen_t addrlen = sizeof(addr_storage); - ASSERT_THAT(s_ = socket(type, SOCK_DGRAM, 0), SyscallSucceeds()); - EXPECT_THAT(close(s_), SyscallSucceeds()); + addr->sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); - ASSERT_THAT(s_ = socket(type, SOCK_STREAM, IPPROTO_UDP), SyscallFails()); + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, Creation) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); } TEST_P(UdpSocketTest, Getsockname) { // Check that we're not bound. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, anyaddr_, addrlen_), 0); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), + 0); - // Bind, then check that we get the right address. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); } TEST_P(UdpSocketTest, Getpeername) { + ASSERT_NO_ERRNO(BindLoopback()); + // Check that we're not connected. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); // Connect, then check that we get the right address. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); } TEST_P(UdpSocketTest, SendNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + // Do send & write, they must fail. char buf[512]; - EXPECT_THAT(send(s_, buf, sizeof(buf), 0), + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), SyscallFailsWithErrno(EDESTADDRREQ)); - EXPECT_THAT(write(s_, buf, sizeof(buf)), SyscallFailsWithErrno(EDESTADDRREQ)); + EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EDESTADDRREQ)); // Use sendto. - ASSERT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Check that we're bound now. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_NE(*Port(&addr), 0); } TEST_P(UdpSocketTest, ConnectBinds) { + ASSERT_NO_ERRNO(BindLoopback()); + // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Check that we're bound now. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); EXPECT_NE(*Port(&addr), 0); } TEST_P(UdpSocketTest, ReceiveNotBound) { char buf[512]; - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); } TEST_P(UdpSocketTest, Bind) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // Try to bind again. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EINVAL)); // Check that we're still bound to the original address. struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[0], addrlen_), 0); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); } TEST_P(UdpSocketTest, BindInUse) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // Try to bind again. - EXPECT_THAT(bind(t_, addr_[0], addrlen_), SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(UdpSocketTest, ReceiveAfterConnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - // Send from t_ to s_. + // Send from sock_ to bind_ char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); - - // Get the address s_ was bound to during connect. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); + ASSERT_NO_ERRNO(BindLoopback()); for (int i = 0; i < 2; i++) { - // Send from t_ to s_. + // Connet sock_ to bound address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from sock to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, + + ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, reinterpret_cast<sockaddr*>(&addr), addrlen), SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - // Disconnect s_. - struct sockaddr addr = {}; - addr.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); - // Connect s_ loopback:TestPort. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + // Disconnect sock_. + struct sockaddr unspec = {}; + unspec.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), + SyscallSucceeds()); } } TEST_P(UdpSocketTest, Connect) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Check that we're connected to the right peer. struct sockaddr_storage peer; socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[0], addrlen_), 0); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); // Try to bind after connect. - EXPECT_THAT(bind(s_, addr_[1], addrlen_), SyscallFailsWithErrno(EINVAL)); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); + + struct sockaddr_storage bind2_storage = InetLoopbackAddr(); + struct sockaddr* bind2_addr = + reinterpret_cast<struct sockaddr*>(&bind2_storage); + FileDescriptor bind2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); // Check that peer name changed. peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); + EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); } -void ConnectAny(AddressFamily family, int sockfd, uint16_t port) { - struct sockaddr_storage addr = {}; - - // Precondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr any = IN6ADDR_ANY_INIT; - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &any, sizeof(in6_addr)), 0); - } - - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } - - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } - - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } +TEST_P(UdpSocketTest, ConnectAnyZero) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - } - - // Postcondition check. - { - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback; - if (family == AddressFamily::kIpv6) { - loopback = IN6ADDR_LOOPBACK_INIT; - } else { - TestAddress const& v4_mapped_loopback = V4MappedLoopback(); - loopback = reinterpret_cast<const struct sockaddr_in6*>( - &v4_mapped_loopback.addr) - ->sin6_addr; - } - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); - addrlen = sizeof(addr); - if (port == 0) { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - } else { - EXPECT_THAT( - getpeername(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - } - } + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); } -TEST_P(UdpSocketTest, ConnectAny) { ConnectAny(GetParam(), s_, 0); } - TEST_P(UdpSocketTest, ConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - ConnectAny(GetParam(), s_, port); + ASSERT_NO_ERRNO(BindAny()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); } -void DisconnectAfterConnectAny(AddressFamily family, int sockfd, int port) { - struct sockaddr_storage addr = {}; +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); - struct sockaddr_storage baddr = {}; - if (family == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin_family = AF_INET; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - addr_in->sin_port = port; - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addrlen = sizeof(*addr_in); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - if (family == AddressFamily::kIpv6) { - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } else { - TestAddress const& v4_mapped_any = V4MappedAny(); - addr_in->sin6_addr = - reinterpret_cast<const struct sockaddr_in6*>(&v4_mapped_any.addr) - ->sin6_addr; - } - } + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); - // TODO(b/138658473): gVisor doesn't allow connecting to the zero port. - if (port == 0) { - SKIP_IF(IsRunningOnGvisor()); - } + Disconnect(sock_.get()); +} - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&baddr), addrlen), - SyscallSucceeds()); - // Now the socket is bound to the loopback address. +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - // Disconnect - addrlen = sizeof(addr); - addr.ss_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceeds()); + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - if (family == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } + Disconnect(sock_.get()); } -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - DisconnectAfterConnectAny(GetParam(), s_, 0); -} +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_NO_ERRNO(BindLoopback()); -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - DisconnectAfterConnectAny(GetParam(), s_, port); -} + // Bind to the next port above bind_. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); + struct sockaddr_storage unspec = {}; + unspec.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), + sizeof(unspec.ss_family)), + SyscallSucceeds()); // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + socklen_t addrlen = sizeof(unspec); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); + EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), + EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), SyscallFailsWithErrno(ENOTCONN)); } TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { - struct sockaddr_storage baddr = {}; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); - // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindAny()); - struct sockaddr_storage addr = {}; + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + + // Connect the socket. + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); // If the socket is bound to ANY and connected to a loopback address, // getsockname() has to return the loopback address. if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(&addr); + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); EXPECT_EQ(addrlen, sizeof(*addr_out)); EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(&addr); + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; EXPECT_EQ(addrlen, sizeof(*addr_out)); EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); @@ -572,91 +532,97 @@ TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { } TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - struct sockaddr_storage baddr = {}; - socklen_t addrlen; - auto port = *Port(reinterpret_cast<struct sockaddr_storage*>(addr_[1])); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_in = reinterpret_cast<struct sockaddr_in*>(&baddr); - addr_in->sin_family = AF_INET; - addr_in->sin_port = port; - addr_in->sin_addr.s_addr = htonl(INADDR_ANY); - } else { - auto addr_in = reinterpret_cast<struct sockaddr_in6*>(&baddr); - addr_in->sin6_family = AF_INET6; - addr_in->sin6_port = port; - addr_in->sin6_scope_id = 0; - addr_in->sin6_addr = IN6ADDR_ANY_INIT; - } - ASSERT_THAT(bind(s_, reinterpret_cast<sockaddr*>(&baddr), addrlen_), - SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + // Connect the socket. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); + Disconnect(sock_.get()); // Check that we're still bound. - addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); + EXPECT_EQ(memcmp(&addr, any, addrlen), 0); addrlen = sizeof(addr); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); } TEST_P(UdpSocketTest, Disconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + for (int i = 0; i < 2; i++) { // Try to connect again. - EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Check that we're connected to the right peer. struct sockaddr_storage peer; socklen_t peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); // Try to disconnect. struct sockaddr_storage addr = {}; addr.ss_family = AF_UNSPEC; - EXPECT_THAT( - connect(s_, reinterpret_cast<sockaddr*>(&addr), sizeof(addr.ss_family)), - SyscallSucceeds()); + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), + sizeof(addr.ss_family)), + SyscallSucceeds()); peerlen = sizeof(peer); - EXPECT_THAT(getpeername(s_, reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); // Check that we're still bound. socklen_t addrlen = sizeof(addr); - EXPECT_THAT(getsockname(s_, reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), 0); + EXPECT_EQ(*Port(&addr), *Port(&any_storage)); } } TEST_P(UdpSocketTest, ConnectBadAddress) { struct sockaddr addr = {}; - addr.sa_family = addr_[0]->sa_family; - ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), + addr.sa_family = GetFamily(); + ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), SyscallFailsWithErrno(EINVAL)); } TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage addr_storage = InetAnyAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Send to a different destination than we're connected to. char buf[512]; - EXPECT_THAT(sendto(s_, buf, sizeof(buf), 0, addr_[1], addrlen_), + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); } @@ -664,24 +630,27 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + // Connect to loopback:bind_addr_+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + // Bind sock to loopback:bind_addr_+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - struct pollfd pfd = {t_, POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), SyscallSucceedsWithValue(1)); // Receive the packet. char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), SyscallSucceedsWithValue(0)); } @@ -689,216 +658,252 @@ TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - // Set t_ to non-blocking. + // Set sock to non-blocking. int opts = 0; - ASSERT_THAT(opts = fcntl(t_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(t_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); char buf[3]; - // Send zero length packet from s_ to t_. - ASSERT_THAT(write(s_, buf, 0), SyscallSucceedsWithValue(0)); + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - struct pollfd pfd = {t_, POLLIN, 0}; + struct pollfd pfd = {sock_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); // Receive the packet. char received[3]; - EXPECT_THAT(read(t_, received, sizeof(received)), + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(t_, received, sizeof(received)), + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), SyscallFailsWithErrno(EAGAIN)); } TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Send some data to s_. + // Send some data to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } TEST_P(UdpSocketTest, SendAndReceiveConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Send some data from t_ to s_. + // Bind sock to loopback:TestPort+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - // Send some data from t_ to s_. + // Send some data from sock to bind_. char buf[512]; - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); - // Check that the data isn't_ received because it was sent from a different + // Check that the data isn't received because it was sent from a different // address than we're connected. - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); } TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - // Send some data from t_ to s_. + // Send some data from sock to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Connect to loopback:TestPort+1. - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); // Receive the data. It works because it was sent before the connect. char received[sizeof(buf)]; - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); // Send again. This time it should not be received. - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); - EXPECT_THAT(recv(s_, buf, sizeof(buf), MSG_DONTWAIT), + EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); } TEST_P(UdpSocketTest, ReceiveFrom) { - // Bind s_ to loopback:TestPort, and connect to loopback:TestPort+1. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Bind t_ to loopback:TestPort+1. - ASSERT_THAT(bind(t_, addr_[1], addrlen_), SyscallSucceeds()); + // Bind sock to loopback:TestPort+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - // Send some data from t_ to s_. + // Send some data from sock to bind_. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, addr_[0], addrlen_), + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(sizeof(buf))); // Receive the data and sender address. char received[sizeof(buf)]; - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr), &addrlen), + struct sockaddr_storage addr2; + socklen_t addr2len = sizeof(addr2); + EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr2), &addr2len), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); + EXPECT_EQ(addr2len, addrlen_); + EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); } TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(s_, SOMAXCONN), SyscallFailsWithErrno(EOPNOTSUPP)); + ASSERT_THAT(listen(sock_.get(), SOMAXCONN), + SyscallFailsWithErrno(EOPNOTSUPP)); } TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(s_, nullptr, nullptr), SyscallFailsWithErrno(EOPNOTSUPP)); + ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), + SyscallFailsWithErrno(EOPNOTSUPP)); } // This test validates that a read shutdown with pending data allows the read // to proceed with the data before returning EAGAIN. TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - char received[512]; + ASSERT_NO_ERRNO(BindLoopback()); - // Bind t_ to loopback:TestPort+2. - ASSERT_THAT(bind(t_, addr_[2], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Verify that we get EWOULDBLOCK when there is nothing to read. - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); const char* buf = "abc"; - EXPECT_THAT(write(t_, buf, 3), SyscallSucceedsWithValue(3)); + EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); int opts = 0; - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(s_, F_SETFL, opts | O_NONBLOCK), SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(s_, F_GETFL), SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); ASSERT_NE(opts & O_NONBLOCK, 0); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(s_, received, 2, 0), SyscallSucceedsWithValue(2)); + EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); // Because we read less than the entire packet length, since it's a packet // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(s_, received, 1, 0), SyscallFailsWithErrno(EWOULDBLOCK)); + EXPECT_THAT(recv(bind_.get(), received, 1, 0), + SyscallFailsWithErrno(EWOULDBLOCK)); } // This test is validating that even after a socket is shutdown if it's // reconnected it will reset the shutdown state. TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); // Connect the socket, then try to shutdown again. - ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); } @@ -907,24 +912,26 @@ TEST_P(UdpSocketTest, ReadShutdown) { // MSG_DONTWAIT blocks indefinitely. SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); + char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - EXPECT_THAT(recv(s_, received, sizeof(received), 0), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(0)); } @@ -932,93 +939,94 @@ TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without // MSG_DONTWAIT blocks indefinitely. SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); char received[512]; - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - EXPECT_THAT(recv(s_, received, sizeof(received), MSG_DONTWAIT), + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), SyscallFailsWithErrno(EWOULDBLOCK)); ScopedThread t([&] { absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(this->s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); }); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(0)); t.Join(); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(0)); } TEST_P(UdpSocketTest, WriteShutdown) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); } TEST_P(UdpSocketTest, SynchronousReceive) { - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Send some data to s_ from another thread. + // Send some data to bind_ from another thread. char buf[512]; RandomizeBuffer(buf, sizeof(buf)); // Receive the data prior to actually starting the other thread. char received[512]; - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); + EXPECT_THAT( + RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); // Start the thread. ScopedThread t([&] { absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT( - sendto(this->t_, buf, sizeof(buf), 0, this->addr_[0], this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, + this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); }); - EXPECT_THAT(RetryEINTR(recv)(s_, received, sizeof(received), 0), + EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), SyscallSucceedsWithValue(512)); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Send 3 packets from t_ to s_. + // Send 3 packets from sock to bind_. constexpr int psize = 100; char buf[3 * psize]; RandomizeBuffer(buf, sizeof(buf)); for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); } // Receive the data as 3 separate packets. char received[6 * psize]; for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(s_, received + i * psize, 3 * psize, 0), + EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), SyscallSucceedsWithValue(psize)); } EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); } TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Direct writes from t_ to s_. - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + // Direct writes from sock to bind_. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. constexpr size_t kPieceSize = 100; char buf[4 * kPieceSize]; RandomizeBuffer(buf, sizeof(buf)); @@ -1030,7 +1038,8 @@ TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); iov[j].iov_len = kPieceSize; } - ASSERT_THAT(writev(t_, iov, 2), SyscallSucceedsWithValue(2 * kPieceSize)); + ASSERT_THAT(writev(sock_.get(), iov, 2), + SyscallSucceedsWithValue(2 * kPieceSize)); } // Receive the data as 2 separate packets. @@ -1042,17 +1051,17 @@ TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); iov[j].iov_len = kPieceSize; } - ASSERT_THAT(readv(s_, iov, 3), SyscallSucceedsWithValue(2 * kPieceSize)); + ASSERT_THAT(readv(bind_.get(), iov, 3), + SyscallSucceedsWithValue(2 * kPieceSize)); } EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); } TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); - // Send 2 packets from t_ to s_, where each packet's data consists of 2 - // discontiguous iovecs. + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. constexpr size_t kPieceSize = 100; char buf[4 * kPieceSize]; RandomizeBuffer(buf, sizeof(buf)); @@ -1065,11 +1074,12 @@ TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { iov[j].iov_len = kPieceSize; } struct msghdr msg = {}; - msg.msg_name = addr_[0]; + msg.msg_name = bind_addr_; msg.msg_namelen = addrlen_; msg.msg_iov = iov; msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(t_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); + ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); } // Receive the data as 2 separate packets. @@ -1084,85 +1094,87 @@ TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { struct msghdr msg = {}; msg.msg_iov = iov; msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(s_, &msg, 0), SyscallSucceedsWithValue(2 * kPieceSize)); + ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); } EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); } TEST_P(UdpSocketTest, FIONREADShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); } TEST_P(UdpSocketTest, FIONREADWriteShutdown) { int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); const char str[] = "abc"; - ASSERT_THAT(send(s_, str, sizeof(str), 0), + ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), SyscallSucceedsWithValue(sizeof(str))); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, sizeof(str)); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, sizeof(str)); } // NOTE: Do not use `FIONREAD` as test name because it will be replaced by the // corresponding macro and become `0x541B`. TEST_P(UdpSocketTest, Fionread) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // Check that the bound socket with an empty buffer reports an empty first // packet. int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - // Send 3 packets from t_ to s_. + // Send 3 packets from sock to bind_. constexpr int psize = 100; char buf[3 * psize]; RandomizeBuffer(buf, sizeof(buf)); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, psize, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(psize)); + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); @@ -1170,30 +1182,30 @@ TEST_P(UdpSocketTest, Fionread) { // Check that regardless of how many packets are in the queue, the size // reported is that of a single packet. n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, psize); } } TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // Check that the bound socket with an empty buffer reports an empty first // packet. int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - // Send 3 packets from t_ to s_. + // Send 3 packets from sock to bind_. constexpr int psize = 100; char buf[3 * psize]; RandomizeBuffer(buf, sizeof(buf)); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; for (int i = 0; i < 3; ++i) { - ASSERT_THAT(sendto(t_, buf + i * psize, 0, 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(0)); + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(0)); // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet // socket does not cause a poll event to be triggered. @@ -1205,40 +1217,77 @@ TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { // Check that regardless of how many packets are in the queue, the size // reported is that of a single packet. n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); } } TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { int n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - // Bind s_ to loopback:TestPort. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); const char str[] = "abc"; - ASSERT_THAT(send(s_, str, 0, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); - EXPECT_THAT(shutdown(s_, SHUT_RD), SyscallSucceeds()); + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); n = -1; - EXPECT_THAT(ioctl(s_, FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); EXPECT_EQ(n, 0); } +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + TEST_P(UdpSocketTest, SoTimestampOffByDefault) { // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by // hostinet. @@ -1246,7 +1295,7 @@ TEST_P(UdpSocketTest, SoTimestampOffByDefault) { int v = -1; socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), SyscallSucceeds()); ASSERT_EQ(v, kSockOptOff); ASSERT_EQ(optlen, sizeof(v)); @@ -1257,18 +1306,19 @@ TEST_P(UdpSocketTest, SoTimestamp) { // supported by hostinet. SKIP_IF(IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); int v = 1; - ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), SyscallSucceeds()); char buf[3]; - // Send zero length packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + // Send zero length packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); @@ -1282,7 +1332,8 @@ TEST_P(UdpSocketTest, SoTimestamp) { msg.msg_control = cmsgbuf; msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); ASSERT_NE(cmsg, nullptr); @@ -1296,36 +1347,37 @@ TEST_P(UdpSocketTest, SoTimestamp) { ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); } TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(s_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); } TEST_P(UdpSocketTest, TimestampIoctl) { // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. SKIP_IF(IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); // There should be no control messages. char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); // A nonzero timeval should be available via ioctl. struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); } @@ -1333,11 +1385,12 @@ TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. SKIP_IF(IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallFailsWithErrno(ENOENT)); + ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); } // Test that the timestamp accessed via SIOCGSTAMP is still accessible after @@ -1347,33 +1400,35 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // supported by hostinet. SKIP_IF(IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); char buf[3]; - // Send packet from t_ to s_. - ASSERT_THAT(RetryEINTR(write)(t_, buf, sizeof(buf)), + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); - struct pollfd pfd = {s_, POLLIN, 0}; + struct pollfd pfd = {bind_.get(), POLLIN, 0}; ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); // There should be no control messages. char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(s_, recv_buf, sizeof(recv_buf))); + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); // A nonzero timeval should be available via ioctl. struct timeval tv = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); // Enable SO_TIMESTAMP and send a message. int v = 1; - EXPECT_THAT(setsockopt(s_, SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(t_, buf, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), SyscallSucceedsWithValue(1)); @@ -1386,13 +1441,14 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { msg.msg_iovlen = 1; msg.msg_control = cmsgbuf; msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &msg, 0), SyscallSucceedsWithValue(0)); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); ASSERT_NE(cmsg, nullptr); // The ioctl should return the exact same values as before. struct timeval tv2 = {}; - ASSERT_THAT(ioctl(s_, SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); ASSERT_EQ(tv.tv_sec, tv2.tv_sec); ASSERT_EQ(tv.tv_usec, tv2.tv_usec); } @@ -1401,8 +1457,8 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) { // outgoing packets, and that a receiving socket with IP_RECVTOS or // IPV6_RECVTCLASS will create the corresponding control message. TEST_P(UdpSocketTest, SetAndReceiveTOS) { - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Allow socket to receive control message. int recv_level = SOL_IP; @@ -1411,9 +1467,9 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { recv_level = SOL_IPV6; recv_type = IPV6_RECVTCLASS; } - ASSERT_THAT( - setsockopt(s_, recv_level, recv_type, &kSockOptOn, sizeof(kSockOptOn)), - SyscallSucceeds()); + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); // Set socket TOS. int sent_level = recv_level; @@ -1422,9 +1478,9 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { sent_type = IPV6_TCLASS; } int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - ASSERT_THAT( - setsockopt(t_, sent_level, sent_type, &sent_tos, sizeof(sent_tos)), - SyscallSucceeds()); + ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, + sizeof(sent_tos)), + SyscallSucceeds()); // Prepare message to send. constexpr size_t kDataLength = 1024; @@ -1436,7 +1492,7 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { sent_msg.msg_iov = &sent_iov; sent_msg.msg_iovlen = 1; - ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), SyscallSucceedsWithValue(kDataLength)); // Receive message. @@ -1454,7 +1510,7 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); received_msg.msg_control = &received_cmsgbuf[0]; received_msg.msg_controllen = received_cmsgbuf.size(); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), SyscallSucceedsWithValue(kDataLength)); struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); @@ -1473,8 +1529,9 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) { TEST_P(UdpSocketTest, SendAndReceiveTOS) { // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); // Allow socket to receive control message. int recv_level = SOL_IP; @@ -1484,9 +1541,9 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { recv_type = IPV6_RECVTCLASS; } int recv_opt = kSockOptOn; - ASSERT_THAT( - setsockopt(s_, recv_level, recv_type, &recv_opt, sizeof(recv_opt)), - SyscallSucceeds()); + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, + sizeof(recv_opt)), + SyscallSucceeds()); // Prepare message to send. constexpr size_t kDataLength = 1024; @@ -1517,7 +1574,7 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { sent_cmsg->cmsg_type = sent_type; *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; - ASSERT_THAT(RetryEINTR(sendmsg)(t_, &sent_msg, 0), + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), SyscallSucceedsWithValue(kDataLength)); // Receive message. @@ -1531,7 +1588,7 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); received_msg.msg_control = &received_cmsgbuf[0]; received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - ASSERT_THAT(RetryEINTR(recvmsg)(s_, &received_msg, 0), + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), SyscallSucceedsWithValue(kDataLength)); struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); @@ -1547,27 +1604,29 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) { TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { // Discover minimum buffer size by setting it to zero. constexpr int kRcvBufSz = 0; - ASSERT_THAT( - setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), - SyscallSucceeds()); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); int min = 0; socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), SyscallSucceeds()); - // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + // Bind bind_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); { // Send data of size min and verify that it's received. std::vector<char> buf(min); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); } { @@ -1576,30 +1635,32 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { // is currently empty. std::vector<char> buf(min + 1); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); } } // Test that receive buffer limits are enforced. TEST_P(UdpSocketTest, RecvBufLimits) { // Bind s_ to loopback. - ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_NO_ERRNO(BindLoopback()); int min = 0; { // Discover minimum buffer size by trying to set it to zero. constexpr int kRcvBufSz = 0; - ASSERT_THAT( - setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, sizeof(kRcvBufSz)), - SyscallSucceeds()); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &min, &min_len), + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), SyscallSucceeds()); } @@ -1610,53 +1671,111 @@ TEST_P(UdpSocketTest, RecvBufLimits) { new_rcv_buf_sz = min * 2; } - ASSERT_THAT(setsockopt(s_, SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, sizeof(new_rcv_buf_sz)), SyscallSucceeds()); int rcv_buf_sz = 0; { socklen_t rcv_buf_len = sizeof(rcv_buf_sz); - ASSERT_THAT( - getsockopt(s_, SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, &rcv_buf_len), - SyscallSucceeds()); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, + &rcv_buf_len), + SyscallSucceeds()); } { std::vector<char> buf(min); RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); int sent = 4; if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { // Linux seems to drop the 4th packet even though technically it should // fit in the receive buffer. - ASSERT_THAT(sendto(t_, buf.data(), buf.size(), 0, addr_[0], addrlen_), - SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); sent++; } for (int i = 0; i < sent - 1; i++) { // Receive the data. std::vector<char> received(buf.size()); - EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); } // The last receive should fail with EAGAIN as the last packet should have // been dropped due to lack of space in the receive buffer. std::vector<char> received(buf.size()); - EXPECT_THAT(recv(s_, received.data(), received.size(), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); } } +#ifndef __fuchsia__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + +#endif // __fuchsia__ + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h index 2fd79d99e..f7e25c805 100644 --- a/test/syscalls/linux/udp_socket_test_cases.h +++ b/test/syscalls/linux/udp_socket_test_cases.h @@ -15,8 +15,12 @@ #ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ #define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ +#include <sys/socket.h> + #include "gtest/gtest.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" namespace gvisor { namespace testing { @@ -32,42 +36,46 @@ class UdpSocketTest // Creates two sockets that will be used by test cases. void SetUp() override; - // Closes the sockets created by SetUp(). - void TearDown() override { - EXPECT_THAT(close(s_), SyscallSucceeds()); - EXPECT_THAT(close(t_), SyscallSucceeds()); + // Binds the socket bind_ to the loopback and updates bind_addr_. + PosixError BindLoopback(); - for (size_t i = 0; i < ABSL_ARRAYSIZE(ports_); ++i) { - ASSERT_NO_ERRNO(FreeAvailablePort(ports_[i])); - } - } + // Binds the socket bind_ to Any and updates bind_addr_. + PosixError BindAny(); - // First UDP socket. - int s_; + // Binds given socket to address addr and updates. + PosixError BindSocket(int socket, struct sockaddr* addr); - // Second UDP socket. - int t_; + // Return initialized Any address to port 0. + struct sockaddr_storage InetAnyAddr(); - // The length of the socket address. - socklen_t addrlen_; + // Return initialized Loopback address to port 0. + struct sockaddr_storage InetLoopbackAddr(); - // Initialized address pointing to loopback and port TestPort+i. - struct sockaddr* addr_[3]; + // Disconnects socket sockfd. + void Disconnect(int sockfd); - // Initialize "any" address. - struct sockaddr* anyaddr_; + // Get family for the test. + int GetFamily(); - // Used ports. - int ports_[3]; + // Socket used by Bind methods + FileDescriptor bind_; - private: - // Storage for the loopback addresses. - struct sockaddr_storage addr_storage_[3]; + // Second socket used for tests. + FileDescriptor sock_; - // Storage for the "any" address. - struct sockaddr_storage anyaddr_storage_; -}; + // Address for bind_ socket. + struct sockaddr* bind_addr_; + // Initialized to the length based on GetFamily(). + socklen_t addrlen_; + + // Storage for bind_addr_. + struct sockaddr_storage bind_addr_storage_; + + private: + // Helper to initialize addrlen_ for the test case. + socklen_t GetAddrLength(); +}; } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 052781445..5418948fe 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -125,12 +125,12 @@ PosixErrorOr<struct stat> Fstat(int fd) { PosixErrorOr<bool> Exists(absl::string_view path) { struct stat stat_buf; - int res = stat(std::string(path).c_str(), &stat_buf); + int res = lstat(std::string(path).c_str(), &stat_buf); if (res < 0) { if (errno == ENOENT) { return false; } - return PosixError(errno, absl::StrCat("stat ", path)); + return PosixError(errno, absl::StrCat("lstat ", path)); } return true; } diff --git a/test/util/fs_util.h b/test/util/fs_util.h index caf19b24d..8cdac23a1 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -44,9 +44,14 @@ PosixErrorOr<std::string> GetCWD(); // can't be determined. PosixErrorOr<bool> Exists(absl::string_view path); -// Returns a stat structure for the given path or an error. +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will be traversed. PosixErrorOr<struct stat> Stat(absl::string_view path); +// Returns a stat structure for the given path or an error. If the path +// represents a symlink, it will not be traversed. +PosixErrorOr<struct stat> Lstat(absl::string_view path); + // Returns a stat struct for the given fd. PosixErrorOr<struct stat> Fstat(int fd); diff --git a/test/util/temp_path.cc b/test/util/temp_path.cc index 9c10b6674..e1bdee7fd 100644 --- a/test/util/temp_path.cc +++ b/test/util/temp_path.cc @@ -56,7 +56,7 @@ void TryDeleteRecursively(std::string const& path) { if (undeleted_dirs || undeleted_files || !status.ok()) { std::cerr << path << ": failed to delete " << undeleted_dirs << " directories and " << undeleted_files - << " files: " << status; + << " files: " << status << std::endl; } } } diff --git a/test/util/test_util.cc b/test/util/test_util.cc index b20758626..d0c1d6426 100644 --- a/test/util/test_util.cc +++ b/test/util/test_util.cc @@ -40,15 +40,15 @@ namespace gvisor { namespace testing { -#define TEST_ON_GVISOR "TEST_ON_GVISOR" -#define GVISOR_NETWORK "GVISOR_NETWORK" -#define GVISOR_VFS "GVISOR_VFS" +constexpr char kGvisorNetwork[] = "GVISOR_NETWORK"; +constexpr char kGvisorVfs[] = "GVISOR_VFS"; +constexpr char kFuseEnabled[] = "FUSE_ENABLED"; bool IsRunningOnGvisor() { return GvisorPlatform() != Platform::kNative; } const std::string GvisorPlatform() { // Set by runner.go. - const char* env = getenv(TEST_ON_GVISOR); + const char* env = getenv(kTestOnGvisor); if (!env) { return Platform::kNative; } @@ -56,12 +56,12 @@ const std::string GvisorPlatform() { } bool IsRunningWithHostinet() { - const char* env = getenv(GVISOR_NETWORK); + const char* env = getenv(kGvisorNetwork); return env && strcmp(env, "host") == 0; } bool IsRunningWithVFS1() { - const char* env = getenv(GVISOR_VFS); + const char* env = getenv(kGvisorVfs); if (env == nullptr) { // If not set, it's running on Linux. return false; @@ -69,6 +69,11 @@ bool IsRunningWithVFS1() { return strcmp(env, "VFS1") == 0; } +bool IsFUSEEnabled() { + const char* env = getenv(kFuseEnabled); + return env && strcmp(env, "TRUE") == 0; +} + // Inline cpuid instruction. Preserve %ebx/%rbx register. In PIC compilations // %ebx contains the address of the global offset table. %rbx is occasionally // used to address stack variables in presence of dynamic allocas. diff --git a/test/util/test_util.h b/test/util/test_util.h index 8e3245b27..373c54f32 100644 --- a/test/util/test_util.h +++ b/test/util/test_util.h @@ -195,6 +195,8 @@ namespace gvisor { namespace testing { +constexpr char kTestOnGvisor[] = "TEST_ON_GVISOR"; + // TestInit must be called prior to RUN_ALL_TESTS. // // This parses all arguments and adjusts argc and argv appropriately. @@ -215,12 +217,15 @@ namespace Platform { constexpr char kNative[] = "native"; constexpr char kPtrace[] = "ptrace"; constexpr char kKVM[] = "kvm"; +constexpr char kFuchsia[] = "fuchsia"; } // namespace Platform bool IsRunningOnGvisor(); const std::string GvisorPlatform(); bool IsRunningWithHostinet(); +// TODO(gvisor.dev/issue/1624): Delete once VFS1 is gone. bool IsRunningWithVFS1(); +bool IsFUSEEnabled(); #ifdef __linux__ void SetupGvisorDeathTest(); @@ -563,6 +568,25 @@ ssize_t ApplyFileIoSyscall(F const& f, size_t const count) { } // namespace internal +inline PosixErrorOr<std::string> ReadAllFd(int fd) { + std::string all; + all.reserve(128 * 1024); // arbitrary. + + std::vector<char> buffer(16 * 1024); + for (;;) { + auto const bytes = RetryEINTR(read)(fd, buffer.data(), buffer.size()); + if (bytes < 0) { + return PosixError(errno, "file read"); + } + if (bytes == 0) { + return std::move(all); + } + if (bytes > 0) { + all.append(buffer.data(), bytes); + } + } +} + inline ssize_t ReadFd(int fd, void* buf, size_t count) { return internal::ApplyFileIoSyscall( [&](size_t completed) { diff --git a/tools/bazel.mk b/tools/bazel.mk index 9f4a40669..82fe11a03 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -15,6 +15,7 @@ # limitations under the License. # See base Makefile. +SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ xargs -n 1 basename 2>/dev/null) @@ -22,22 +23,41 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ # Bazel container configuration (see below). USER ?= gvisor HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +BUILDER_BASE := gvisor.dev/images/default +BUILDER_IMAGE := gvisor.dev/images/builder +BUILDER_NAME ?= gvisor-builder-$(HASH) DOCKER_NAME ?= gvisor-bazel-$(HASH) DOCKER_PRIVILEGED ?= --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock -# Non-configurable. +# Bazel flags. +OPTIONS += --test_output=errors --keep_going --verbose_failures=true +BAZEL := bazel $(STARTUP_OPTIONS) + +# Basic options. UID := $(shell id -u ${USER}) GID := $(shell id -g ${USER}) USERADD_OPTIONS := FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS) +FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID) +FULL_DOCKER_RUN_OPTIONS += --entrypoint "" +FULL_DOCKER_RUN_OPTIONS += --init FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)" FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)" FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp" +FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID) +FULL_DOCKER_EXEC_OPTIONS += --interactive +ifeq (true,$(shell [[ -t 0 ]] && echo true)) +FULL_DOCKER_EXEC_OPTIONS += --tty +endif + +# Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" +FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) +FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) ifneq ($(GID),$(DOCKER_GROUP)) USERADD_OPTIONS += --groups $(DOCKER_GROUP) @@ -45,7 +65,40 @@ GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) && FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP) endif endif -SHELL=/bin/bash -o pipefail + +# Add KVM passthrough options. +ifneq (,$(wildcard /dev/kvm)) +FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm +KVM_GROUP := $(shell stat -c '%g' /dev/kvm) +ifneq ($(GID),$(KVM_GROUP)) +USERADD_OPTIONS += --groups $(KVM_GROUP) +GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) && +FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP) +endif +endif + +# Load the appropriate config. +ifneq (,$(BAZEL_CONFIG)) +OPTIONS += --config=$(BAZEL_CONFIG) +endif + +# NOTE: we pass -l to useradd below because otherwise you can hit a bug +# best described here: +# https://github.com/moby/moby/issues/5419#issuecomment-193876183 +# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out +# out of disk space. +bazel-image: load-default + @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi + docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ + $(BUILDER_BASE) \ + sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ + $(GROUPADD_DOCKER) \ + useradd -l --uid $(UID) --non-unique --no-create-home \ + --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" + docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) + @docker rm -f $(BUILDER_NAME) +.PHONY: bazel-image ## ## Bazel helpers. @@ -60,40 +113,37 @@ SHELL=/bin/bash -o pipefail ## GCLOUD_CONFIG - The gcloud config directory (detect: detected). ## DOCKER_SOCKET - The Docker socket (default: detected). ## -bazel-server-start: load-default ## Starts the bazel server. +bazel-server-start: bazel-image ## Starts the bazel server. @mkdir -p $(BAZEL_CACHE) @mkdir -p $(GCLOUD_CONFIG) - docker run -d --rm \ - --init \ - --name $(DOCKER_NAME) \ - --user 0:0 $(DOCKER_GROUP_OPTIONS) \ + @if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi + # This command runs a bazel server, and the container sticks around + # until the bazel server exits. This should ensure that it does not + # exit in the middle of running a build, but also it won't stick around + # forever. The build commands wrap around an appropriate exec into the + # container in order to perform work via the bazel client. + docker run -d --rm --name $(DOCKER_NAME) \ -v "$(CURDIR):$(CURDIR)" \ --workdir "$(CURDIR)" \ - --entrypoint "" \ $(FULL_DOCKER_RUN_OPTIONS) \ - gvisor.dev/images/default \ - sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ - $(GROUPADD_DOCKER) \ - useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ - bazel version && \ - exec tail --pid=\$$(bazel info server_pid) -f /dev/null" - @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \ - if ! docker ps | grep $(DOCKER_NAME); then exit 1; else sleep 1; fi; done + $(BUILDER_IMAGE) \ + sh -c "tail -f --pid=\$$($(BAZEL) info server_pid)" .PHONY: bazel-server-start bazel-shutdown: ## Shuts down a running bazel server. - @docker exec --user $(UID):$(GID) $(DOCKER_NAME) bazel shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \ + rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]] .PHONY: bazel-shutdown bazel-alias: ## Emits an alias that can be used within the shell. - @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'" + @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'" .PHONY: bazel-alias bazel-server: ## Ensures that the server exists. Used as an internal target. - @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start .PHONY: bazel-server -build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c 'bazel $(STARTUP_OPTIONS) build $(OPTIONS) $(TARGETS)' +build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)' build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ @@ -109,7 +159,7 @@ copy: bazel-server ifeq (,$(DESTINATION)) $(error Destination not provided.) endif - @$(call build_paths,cp -a {} $(DESTINATION)) + @$(call build_paths,cp -fa {} $(DESTINATION)) run: bazel-server @$(call build_paths,{} $(ARGS)) @@ -120,5 +170,9 @@ sudo: bazel-server .PHONY: sudo test: bazel-server - @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel $(STARTUP_OPTIONS) test $(OPTIONS) $(TARGETS) + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS) .PHONY: test + +query: bazel-server + @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) query $(OPTIONS) '$(TARGETS)' +.PHONY: query diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index f2f80bae1..3f809065d 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -49,3 +49,40 @@ rbe_toolchain( toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", ) + +# Updated versions of the above, compatible with bazel3. +rbe_platform( + name = "rbe_ubuntu1604_bazel3", + constraint_values = [ + "@bazel_tools//platforms:x86_64", + "@bazel_tools//platforms:linux", + "@bazel_tools//tools/cpp:clang", + "@bazel_toolchains_bazel3//constraints:xenial", + "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan", + ], + remote_execution_properties = """ + properties: { + name: "container-image" + value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272" + } + properties: { + name: "dockerAddCapabilities" + value: "SYS_ADMIN" + } + properties: { + name: "dockerPrivileged" + value: "true" + } + """, +) + +rbe_toolchain( + name = "cc-toolchain-clang-x86_64-default_bazel3", + exec_compatible_with = [], + tags = [ + "manual", + ], + target_compatible_with = [], + toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index 620c460de..3db8e13d0 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -32,6 +32,9 @@ rbe_platform = native.platform rbe_toolchain = native.toolchain vdso_linker_option = "-fuse-ld=gold " +def short_path(path): + return path + def proto_library(name, has_services = None, **kwargs): native.proto_library( name = name, diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index 571e9a6e6..f8def4823 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -88,7 +88,7 @@ const ( testMagic = "// +mustescape:" // exempt is the exemption annotation. - exempt = "// escapes:" + exempt = "// escapes" ) // escapingBuiltins are builtins known to escape. @@ -546,7 +546,7 @@ func run(pass *analysis.Pass) (interface{}, error) { for _, cg := range f.Comments { for _, c := range cg.List { p := pass.Fset.Position(c.Slash) - if strings.HasPrefix(c.Text, exempt) { + if strings.HasPrefix(strings.ToLower(c.Text), exempt) { exemptions[LinePosition{ Filename: filepath.Base(p.Filename), Line: p.Line, diff --git a/tools/defs.bzl b/tools/defs.bzl index 40afcdb79..e35e29634 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,7 +7,7 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") load("//tools/nogo:defs.bzl", "nogo_test") @@ -38,6 +38,7 @@ py_requirement = _py_requirement py_test = _py_test select_arch = _select_arch select_system = _select_system +short_path = _short_path rbe_platform = _rbe_platform rbe_toolchain = _rbe_toolchain vdso_linker_option = _vdso_linker_option diff --git a/tools/go_generics/BUILD b/tools/go_generics/BUILD index 32a949c93..558826bf1 100644 --- a/tools/go_generics/BUILD +++ b/tools/go_generics/BUILD @@ -12,27 +12,3 @@ go_binary( visibility = ["//:sandbox"], deps = ["//tools/go_generics/globals"], ) - -genrule( - name = "go_generics_tests", - srcs = glob(["generics_tests/**"]) + [":go_generics"], - outs = ["go_generics_tests.tgz"], - cmd = "tar -czvhf $@ $(SRCS)", -) - -genrule( - name = "go_generics_test_bundle", - srcs = [ - ":go_generics_tests.tgz", - ":go_generics_unittest.sh", - ], - outs = ["go_generics_test.sh"], - cmd = "cat $(location :go_generics_unittest.sh) $(location :go_generics_tests.tgz) > $@", - executable = True, -) - -sh_test( - name = "go_generics_test", - size = "small", - srcs = ["go_generics_test.sh"], -) diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index 8c9995fd4..33329cf28 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -1,11 +1,24 @@ +"""Generics support via go_generics.""" + +TemplateInfo = provider( + fields = { + "types": "required types", + "opt_types": "optional types", + "consts": "required consts", + "opt_consts": "optional consts", + "deps": "package dependencies", + "file": "merged template", + }, +) + def _go_template_impl(ctx): - input = ctx.files.srcs + srcs = ctx.files.srcs output = ctx.outputs.out - args = ["-o=%s" % output.path] + [f.path for f in input] + args = ["-o=%s" % output.path] + [f.path for f in srcs] ctx.actions.run( - inputs = input, + inputs = srcs, outputs = [output], mnemonic = "GoGenericsTemplate", progress_message = "Building Go template %s" % ctx.label, @@ -13,14 +26,14 @@ def _go_template_impl(ctx): executable = ctx.executable._tool, ) - return struct( + return [TemplateInfo( types = ctx.attr.types, opt_types = ctx.attr.opt_types, consts = ctx.attr.consts, opt_consts = ctx.attr.opt_consts, deps = ctx.attr.deps, file = output, - ) + )] """ Generates a Go template from a set of Go files. @@ -43,7 +56,7 @@ go_template = rule( implementation = _go_template_impl, attrs = { "srcs": attr.label_list(mandatory = True, allow_files = True), - "deps": attr.label_list(allow_files = True), + "deps": attr.label_list(allow_files = True, cfg = "target"), "types": attr.string_list(), "opt_types": attr.string_list(), "consts": attr.string_list(), @@ -55,8 +68,14 @@ go_template = rule( }, ) +TemplateInstanceInfo = provider( + fields = { + "srcs": "source files", + }, +) + def _go_template_instance_impl(ctx): - template = ctx.attr.template + template = ctx.attr.template[TemplateInfo] output = ctx.outputs.out # Check that all required types are defined. @@ -81,20 +100,21 @@ def _go_template_instance_impl(ctx): # Build the argument list. args = ["-i=%s" % template.file.path, "-o=%s" % output.path] - args += ["-p=%s" % ctx.attr.package] + if ctx.attr.package: + args.append("-p=%s" % ctx.attr.package) if len(ctx.attr.prefix) > 0: - args += ["-prefix=%s" % ctx.attr.prefix] + args.append("-prefix=%s" % ctx.attr.prefix) if len(ctx.attr.suffix) > 0: - args += ["-suffix=%s" % ctx.attr.suffix] + args.append("-suffix=%s" % ctx.attr.suffix) args += [("-t=%s=%s" % (p[0], p[1])) for p in ctx.attr.types.items()] args += [("-c=%s=%s" % (p[0], p[1])) for p in ctx.attr.consts.items()] args += [("-import=%s=%s" % (p[0], p[1])) for p in ctx.attr.imports.items()] if ctx.attr.anon: - args += ["-anon"] + args.append("-anon") ctx.actions.run( inputs = [template.file], @@ -105,9 +125,9 @@ def _go_template_instance_impl(ctx): executable = ctx.executable._tool, ) - return struct( - files = depset([output]), - ) + return [TemplateInstanceInfo( + srcs = [output], + )] """ Instantiates a Go template by replacing all generic types with concrete ones. @@ -125,14 +145,14 @@ Args: go_template_instance = rule( implementation = _go_template_instance_impl, attrs = { - "template": attr.label(mandatory = True, providers = ["types"]), + "template": attr.label(mandatory = True), "prefix": attr.string(), "suffix": attr.string(), "types": attr.string_dict(), "consts": attr.string_dict(), "imports": attr.string_dict(), "anon": attr.bool(mandatory = False, default = False), - "package": attr.string(mandatory = True), + "package": attr.string(mandatory = False), "out": attr.output(mandatory = True), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), }, diff --git a/tools/go_generics/generics_tests/all_stmts/opts.txt b/tools/go_generics/generics_tests/all_stmts/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_stmts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/all_types/opts.txt b/tools/go_generics/generics_tests/all_types/opts.txt deleted file mode 100644 index c9d0e09bf..000000000 --- a/tools/go_generics/generics_tests/all_types/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q diff --git a/tools/go_generics/generics_tests/anon/opts.txt b/tools/go_generics/generics_tests/anon/opts.txt deleted file mode 100644 index a5e9d26de..000000000 --- a/tools/go_generics/generics_tests/anon/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New -anon diff --git a/tools/go_generics/generics_tests/consts/opts.txt b/tools/go_generics/generics_tests/consts/opts.txt deleted file mode 100644 index 4fb59dce8..000000000 --- a/tools/go_generics/generics_tests/consts/opts.txt +++ /dev/null @@ -1 +0,0 @@ --c=c1=20 -c=z=600 -c=v=3.3 -c=s="def" -c=A=20 -c=C=100 -c=S="def" -c=T="ABC" diff --git a/tools/go_generics/generics_tests/imports/opts.txt b/tools/go_generics/generics_tests/imports/opts.txt deleted file mode 100644 index 87324be79..000000000 --- a/tools/go_generics/generics_tests/imports/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=sync.Mutex -c=n=math.Uint32 -c=m=math.Uint64 -import=sync=sync -import=math=mymathpath diff --git a/tools/go_generics/generics_tests/remove_typedef/opts.txt b/tools/go_generics/generics_tests/remove_typedef/opts.txt deleted file mode 100644 index 9c8ecaada..000000000 --- a/tools/go_generics/generics_tests/remove_typedef/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=U diff --git a/tools/go_generics/generics_tests/simple/opts.txt b/tools/go_generics/generics_tests/simple/opts.txt deleted file mode 100644 index 7832ef66f..000000000 --- a/tools/go_generics/generics_tests/simple/opts.txt +++ /dev/null @@ -1 +0,0 @@ --t=T=Q -suffix=New diff --git a/tools/go_generics/go_generics_unittest.sh b/tools/go_generics/go_generics_unittest.sh deleted file mode 100755 index 44b22db91..000000000 --- a/tools/go_generics/go_generics_unittest.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Bash "safe-mode": Treat command failures as fatal (even those that occur in -# pipes), and treat unset variables as errors. -set -eu -o pipefail - -# This file will be generated as a self-extracting shell script in order to -# eliminate the need for any runtime dependencies. The tarball at the end will -# include the go_generics binary, as well as a subdirectory named -# generics_tests. See the BUILD file for more information. -declare -r temp=$(mktemp -d) -function cleanup() { - rm -rf "${temp}" -} -# trap cleanup EXIT - -# Print message in "$1" then exit with status 1. -function die () { - echo "$1" 1>&2 - exit 1 -} - -# This prints the line number of __BUNDLE__ below, that should be the last line -# of this script. After that point, the concatenated archive will be the -# contents. -declare -r tgz=`awk '/^__BUNDLE__/ {print NR + 1; exit 0; }' $0` -tail -n+"${tgz}" $0 | tar -xzv -C "${temp}" - -# The target for the test. -declare -r binary="$(find ${temp} -type f -a -name go_generics)" -declare -r input_dirs="$(find ${temp} -type d -a -name generics_tests)/*" - -# Go through all test cases. -for f in ${input_dirs}; do - base=$(basename "${f}") - - # Run go_generics on the input file. - opts=$(head -n 1 ${f}/opts.txt) - out="${f}/output/generated.go" - expected="${f}/output/output.go" - ${binary} ${opts} "-i=${f}/input.go" "-o=${out}" || die "go_generics failed for test case \"${base}\"" - - # Compare the outputs. - diff ${expected} ${out} - if [ $? -ne 0 ]; then - echo "Expected:" - cat ${expected} - echo "Actual:" - cat ${out} - die "Actual output is different from expected for test \"${base}\"" - fi -done - -echo "PASS" -exit 0 -__BUNDLE__ diff --git a/tools/go_generics/tests/BUILD b/tools/go_generics/tests/BUILD new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/tools/go_generics/tests/BUILD diff --git a/tools/go_generics/tests/all_stmts/BUILD b/tools/go_generics/tests/all_stmts/BUILD new file mode 100644 index 000000000..a4a7c775a --- /dev/null +++ b/tools/go_generics/tests/all_stmts/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_stmts", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_stmts/input.go b/tools/go_generics/tests/all_stmts/input.go index 4791d1ff1..4791d1ff1 100644 --- a/tools/go_generics/generics_tests/all_stmts/input.go +++ b/tools/go_generics/tests/all_stmts/input.go diff --git a/tools/go_generics/generics_tests/all_stmts/output/output.go b/tools/go_generics/tests/all_stmts/output.go index a53d84535..a53d84535 100644 --- a/tools/go_generics/generics_tests/all_stmts/output/output.go +++ b/tools/go_generics/tests/all_stmts/output.go diff --git a/tools/go_generics/tests/all_types/BUILD b/tools/go_generics/tests/all_types/BUILD new file mode 100644 index 000000000..60b1fd314 --- /dev/null +++ b/tools/go_generics/tests/all_types/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "all_types", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/all_types/input.go b/tools/go_generics/tests/all_types/input.go index 3575d02ec..6f85bbb69 100644 --- a/tools/go_generics/generics_tests/all_types/input.go +++ b/tools/go_generics/tests/all_types/input.go @@ -14,7 +14,9 @@ package tests -import "./lib" +import ( + "./lib" +) type T int diff --git a/tools/go_generics/generics_tests/all_types/lib/lib.go b/tools/go_generics/tests/all_types/lib/lib.go index 988786496..988786496 100644 --- a/tools/go_generics/generics_tests/all_types/lib/lib.go +++ b/tools/go_generics/tests/all_types/lib/lib.go diff --git a/tools/go_generics/generics_tests/all_types/output/output.go b/tools/go_generics/tests/all_types/output.go index 41fd147a1..c0bbebfe7 100644 --- a/tools/go_generics/generics_tests/all_types/output/output.go +++ b/tools/go_generics/tests/all_types/output.go @@ -14,7 +14,9 @@ package main -import "./lib" +import ( + "./lib" +) type newType struct { a Q diff --git a/tools/go_generics/tests/anon/BUILD b/tools/go_generics/tests/anon/BUILD new file mode 100644 index 000000000..ef24f4b25 --- /dev/null +++ b/tools/go_generics/tests/anon/BUILD @@ -0,0 +1,18 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "anon", + anon = True, + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/anon/input.go b/tools/go_generics/tests/anon/input.go index 44086d522..44086d522 100644 --- a/tools/go_generics/generics_tests/anon/input.go +++ b/tools/go_generics/tests/anon/input.go diff --git a/tools/go_generics/generics_tests/anon/output/output.go b/tools/go_generics/tests/anon/output.go index 160cddf79..7fa791853 100644 --- a/tools/go_generics/generics_tests/anon/output/output.go +++ b/tools/go_generics/tests/anon/output.go @@ -35,8 +35,8 @@ func (f FooNew) GetBar(name string) Q { func foobarNew() { a := BazNew{} - a.Q = 0 // should not be renamed, this is a limitation + a.Q = 0 b := otherpkg.UnrelatedType{} - b.Q = 0 // should not be renamed, this is a limitation + b.Q = 0 } diff --git a/tools/go_generics/tests/consts/BUILD b/tools/go_generics/tests/consts/BUILD new file mode 100644 index 000000000..fd7caccad --- /dev/null +++ b/tools/go_generics/tests/consts/BUILD @@ -0,0 +1,23 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "consts", + consts = { + "c1": "20", + "z": "600", + "v": "3.3", + "s": "\"def\"", + "A": "20", + "C": "100", + "S": "\"def\"", + "T": "\"ABC\"", + }, + inputs = ["input.go"], + output = "output.go", +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/consts/input.go b/tools/go_generics/tests/consts/input.go index 04b95fcc6..04b95fcc6 100644 --- a/tools/go_generics/generics_tests/consts/input.go +++ b/tools/go_generics/tests/consts/input.go diff --git a/tools/go_generics/generics_tests/consts/output/output.go b/tools/go_generics/tests/consts/output.go index 18d316cc9..18d316cc9 100644 --- a/tools/go_generics/generics_tests/consts/output/output.go +++ b/tools/go_generics/tests/consts/output.go diff --git a/tools/go_generics/tests/defs.bzl b/tools/go_generics/tests/defs.bzl new file mode 100644 index 000000000..6277c3947 --- /dev/null +++ b/tools/go_generics/tests/defs.bzl @@ -0,0 +1,67 @@ +"""Generics tests.""" + +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") + +def _go_generics_test_impl(ctx): + runner = ctx.actions.declare_file(ctx.label.name) + runner_content = "\n".join([ + "#!/bin/bash", + "exec diff --ignore-blank-lines --ignore-matching-lines=^[[:space:]]*// %s %s" % ( + ctx.files.template_output[0].short_path, + ctx.files.expected_output[0].short_path, + ), + "", + ]) + ctx.actions.write(runner, runner_content, is_executable = True) + return [DefaultInfo( + executable = runner, + runfiles = ctx.runfiles( + files = ctx.files.template_output + ctx.files.expected_output, + collect_default = True, + collect_data = True, + ), + )] + +_go_generics_test = rule( + implementation = _go_generics_test_impl, + attrs = { + "template_output": attr.label(mandatory = True, allow_single_file = True), + "expected_output": attr.label(mandatory = True, allow_single_file = True), + }, + test = True, +) + +def go_generics_test(name, inputs, output, types = None, consts = None, **kwargs): + """Instantiates a generics test. + + Args: + name: the name of the test. + inputs: all the input files. + output: the output files. + types: the template types (dictionary). + consts: the template consts (dictionary). + **kwargs: additional arguments for the template_instance. + """ + if types == None: + types = dict() + if consts == None: + consts = dict() + go_template( + name = name + "_template", + srcs = inputs, + types = types.keys(), + consts = consts.keys(), + ) + go_template_instance( + name = name + "_output", + template = ":" + name + "_template", + out = name + "_output.go", + types = types, + consts = consts, + **kwargs + ) + _go_generics_test( + name = name + "_test", + template_output = name + "_output.go", + expected_output = output, + ) diff --git a/tools/go_generics/tests/imports/BUILD b/tools/go_generics/tests/imports/BUILD new file mode 100644 index 000000000..a86223d41 --- /dev/null +++ b/tools/go_generics/tests/imports/BUILD @@ -0,0 +1,24 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "imports", + consts = { + "n": "math.Uint32", + "m": "math.Uint64", + }, + imports = { + "sync": "sync", + "math": "mymathpath", + }, + inputs = ["input.go"], + output = "output.go", + types = { + "T": "sync.Mutex", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/imports/input.go b/tools/go_generics/tests/imports/input.go index 0f032c2a1..0f032c2a1 100644 --- a/tools/go_generics/generics_tests/imports/input.go +++ b/tools/go_generics/tests/imports/input.go diff --git a/tools/go_generics/generics_tests/imports/output/output.go b/tools/go_generics/tests/imports/output.go index 2488ca58c..2488ca58c 100644 --- a/tools/go_generics/generics_tests/imports/output/output.go +++ b/tools/go_generics/tests/imports/output.go diff --git a/tools/go_generics/tests/remove_typedef/BUILD b/tools/go_generics/tests/remove_typedef/BUILD new file mode 100644 index 000000000..46457cec6 --- /dev/null +++ b/tools/go_generics/tests/remove_typedef/BUILD @@ -0,0 +1,16 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "remove_typedef", + inputs = ["input.go"], + output = "output.go", + types = { + "T": "U", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/remove_typedef/input.go b/tools/go_generics/tests/remove_typedef/input.go index cf632bae7..cf632bae7 100644 --- a/tools/go_generics/generics_tests/remove_typedef/input.go +++ b/tools/go_generics/tests/remove_typedef/input.go diff --git a/tools/go_generics/generics_tests/remove_typedef/output/output.go b/tools/go_generics/tests/remove_typedef/output.go index d44fd8e1c..d44fd8e1c 100644 --- a/tools/go_generics/generics_tests/remove_typedef/output/output.go +++ b/tools/go_generics/tests/remove_typedef/output.go diff --git a/tools/go_generics/tests/simple/BUILD b/tools/go_generics/tests/simple/BUILD new file mode 100644 index 000000000..4b9265ea4 --- /dev/null +++ b/tools/go_generics/tests/simple/BUILD @@ -0,0 +1,17 @@ +load("//tools/go_generics/tests:defs.bzl", "go_generics_test") + +go_generics_test( + name = "simple", + inputs = ["input.go"], + output = "output.go", + suffix = "New", + types = { + "T": "Q", + }, +) + +# @unused +glaze_ignore = [ + "input.go", + "output.go", +] diff --git a/tools/go_generics/generics_tests/simple/input.go b/tools/go_generics/tests/simple/input.go index 2a917f16c..2a917f16c 100644 --- a/tools/go_generics/generics_tests/simple/input.go +++ b/tools/go_generics/tests/simple/input.go diff --git a/tools/go_generics/generics_tests/simple/output/output.go b/tools/go_generics/tests/simple/output.go index 6bfa0b25b..6bfa0b25b 100644 --- a/tools/go_generics/generics_tests/simple/output/output.go +++ b/tools/go_generics/tests/simple/output.go diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 4886efddf..68d759083 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -9,11 +9,9 @@ automatically generating code to marshal go data structures to memory. `binary.Marshal` by moving the go runtime reflection necessary to marshal a struct to compile-time. -`go_marshal` automatically generates implementations for `abi.Marshallable` and -`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall -implementations) can directly invoke `safemem.Reader.ReadToBlocks` and -`safemem.Writer.WriteFromBlocks`. Data structures that require custom -serialization will have manual implementations for these interfaces. +`go_marshal` automatically generates implementations for `marshal.Marshallable` +and `safemem.{Reader,Writer}`. Data structures that require custom serialization +will have manual implementations for these interfaces. Data structures can be flagged for code generation by adding a struct-level comment `// +marshal`. diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 177013dbb..19bcd4e6a 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -413,13 +413,13 @@ func (g *Generator) Run() error { for _, t := range g.collectMarshallableTypes(a, fsets[i]) { impl := g.generateOne(t, fsets[i]) // Collect Marshallable types referenced by the generated code. - for ref, _ := range impl.ms { + for ref := range impl.ms { ms[ref] = struct{}{} } impls = append(impls, impl) // Collect imports referenced by the generated code and add them to // the list of imports we need to copy to the generated code. - for name, _ := range impl.is { + for name := range impl.is { if !g.imports.markUsed(name) { panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 9cd3c9579..4b9cea08a 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -268,6 +268,10 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) + g.emit("%s.MarshalBytes(dst)\n", g.r) + } if thisPacked { g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") @@ -277,16 +281,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) }) g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.MarshalBytes(dst)\n", g.r) - }) + g.inIndent(fallback) g.emit("}\n") } else { g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r) } } else { - g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + fallback() } }) g.emit("}\n\n") @@ -294,6 +295,10 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) g.inIndent(func() { + fallback := func() { + g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) + g.emit("%s.UnmarshalBytes(src)\n", g.r) + } if thisPacked { g.recordUsedImport("safecopy") g.recordUsedImport("unsafe") @@ -303,16 +308,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) }) g.emit("} else {\n") - g.inIndent(func() { - g.emit("%s.UnmarshalBytes(src)\n", g.r) - }) + g.inIndent(fallback) g.emit("}\n") } else { g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r) } } else { - g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + fallback() } }) g.emit("}\n\n") @@ -463,8 +465,10 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, }) g.emit("}\n\n") - g.emit("// Handle any final partial object.\n") - g.emit("if length < size*count && length%size != 0 {\n") + g.emit("// Handle any final partial object. buf is guaranteed to be long enough for the\n") + g.emit("// final element, but may not contain valid data for the entire range. This may\n") + g.emit("// result in unmarshalling zero values for some parts of the object.\n") + g.emit("if length%size != 0 {\n") g.inIndent(func() { g.emit("idx := limit\n") g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go index cb2166252..85b196f08 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/tools/go_marshal/marshal/marshal.go @@ -58,18 +58,12 @@ type Marshallable interface { // likely make use of the type of these fields). SizeBytes() int - // MarshalBytes serializes a copy of a type to dst. dst may be smaller than - // SizeBytes(), which results in a part of the struct being marshalled. Note - // that this may have unexpected results for non-packed types, as implicit - // padding needs to be taken into account when reasoning about how much of - // the type is serialized. + // MarshalBytes serializes a copy of a type to dst. + // Precondition: dst must be at least SizeBytes() in length. MarshalBytes(dst []byte) - // UnmarshalBytes deserializes a type from src. src may be smaller than - // SizeBytes(), which results in a partially deserialized struct. Note that - // this may have unexpected results for non-packed types, as implicit - // padding needs to be taken into account when reasoning about how much of - // the type is deserialized. + // UnmarshalBytes deserializes a type from src. + // Precondition: src must be at least SizeBytes() in length. UnmarshalBytes(src []byte) // Packed returns true if the marshalled size of the type is the same as the @@ -89,8 +83,8 @@ type Marshallable interface { // representation to the dst buffer. This is only safe to do when the type // has no implicit padding, see Marshallable.Packed. When Packed would // return false, MarshalUnsafe should fall back to the safer but slower - // MarshalBytes. dst may be smaller than SizeBytes(), see comment for - // MarshalBytes for implications. + // MarshalBytes. + // Precondition: dst must be at least SizeBytes() in length. MarshalUnsafe(dst []byte) // UnmarshalUnsafe deserializes a type by directly copying to the underlying @@ -99,8 +93,8 @@ type Marshallable interface { // This allows much faster unmarshalling of types which have no implicit // padding, see Marshallable.Packed. When Packed would return false, // UnmarshalUnsafe should fall back to the safer but slower unmarshal - // mechanism implemented in UnmarshalBytes. src may be smaller than - // SizeBytes(), see comment for UnmarshalBytes for implications. + // mechanism implemented in UnmarshalBytes. + // Precondition: src must be at least SizeBytes() in length. UnmarshalUnsafe(src []byte) // CopyIn deserializes a Marshallable type from a task's memory. This may @@ -149,14 +143,16 @@ type Marshallable interface { // // Generates four additional functions for marshalling slices of Foos like this: // -// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It's -// // more efficient that repeatedly calling calling Foo.MarshalUnsafe over a -// // []Foo in a loop. +// // MarshalUnsafeFooSlice is like Foo.MarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.MarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: dst must be at least len(src)*Foo.SizeBytes() in length. // func MarshalUnsafeFooSlice(src []Foo, dst []byte) (int, error) { ... } // -// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It's -// // more efficient that repeatedly calling calling Foo.UnmarshalUnsafe over a -// // []Foo in a loop. +// // UnmarshalUnsafeFooSlice is like Foo.UnmarshalUnsafe, buf for a []Foo. It +// // might be more efficient that repeatedly calling Foo.UnmarshalUnsafe +// // over a []Foo in a loop if the type is Packed. +// // Preconditions: src must be at least len(dst)*Foo.SizeBytes() in length. // func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } // // // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. diff --git a/tools/go_marshal/primitive/primitive.go b/tools/go_marshal/primitive/primitive.go index ebcf130ae..d93edda8b 100644 --- a/tools/go_marshal/primitive/primitive.go +++ b/tools/go_marshal/primitive/primitive.go @@ -17,10 +17,22 @@ package primitive import ( + "io" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/marshal" ) +// Int8 is a marshal.Marshallable implementation for int8. +// +// +marshal slice:Int8Slice:inner +type Int8 int8 + +// Uint8 is a marshal.Marshallable implementation for uint8. +// +// +marshal slice:Uint8Slice:inner +type Uint8 uint8 + // Int16 is a marshal.Marshallable implementation for int16. // // +marshal slice:Int16Slice:inner @@ -51,6 +63,66 @@ type Int64 int64 // +marshal slice:Uint64Slice:inner type Uint64 uint64 +// ByteSlice is a marshal.Marshallable implementation for []byte. +// This is a convenience wrapper around a dynamically sized type, and can't be +// embedded in other marshallable types because it breaks assumptions made by +// go-marshal internals. It violates the "no dynamically-sized types" +// constraint of the go-marshal library. +type ByteSlice []byte + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (b *ByteSlice) SizeBytes() int { + return len(*b) +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (b *ByteSlice) MarshalBytes(dst []byte) { + copy(dst, *b) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (b *ByteSlice) UnmarshalBytes(src []byte) { + copy(*b, src) +} + +// Packed implements marshal.Marshallable.Packed. +func (b *ByteSlice) Packed() bool { + return false +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (b *ByteSlice) MarshalUnsafe(dst []byte) { + b.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (b *ByteSlice) UnmarshalUnsafe(src []byte) { + b.UnmarshalBytes(src) +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyInBytes(addr, *b) +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { + return task.CopyOutBytes(addr, *b) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { + return task.CopyOutBytes(addr, (*b)[:limit]) +} + +// WriteTo implements io.WriterTo.WriteTo. +func (b *ByteSlice) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(*b) + return int64(n), err +} + +var _ marshal.Marshallable = (*ByteSlice)(nil) + // Below, we define some convenience functions for marshalling primitive types // using the newtypes above, without requiring superfluous casts. diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 2fbcc8a03..3d989823a 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -39,6 +39,6 @@ go_test( "//pkg/usermem", "//tools/go_marshal/analysis", "//tools/go_marshal/marshal", - "@com_github_google_go-cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/tools/go_mod.sh b/tools/go_mod.sh deleted file mode 100755 index 84b779d6d..000000000 --- a/tools/go_mod.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# Copyright 2020 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -# Build the :gopath target. -bazel build //:gopath -declare -r gopathdir="bazel-bin/gopath/src/gvisor.dev/gvisor/" - -# Copy go.mod and execute the command. -cp -a go.mod go.sum "${gopathdir}" -(cd "${gopathdir}" && go mod "$@") -cp -a "${gopathdir}/go.mod" "${gopathdir}/go.sum" . - -# Cleanup the WORKSPACE file. -bazel run //:gazelle -- update-repos -from_file=go.mod diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 309ee9c21..4f6ed208a 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -103,7 +103,7 @@ type scanFunctions struct { // skipped if nil. // // Fields tagged nosave are skipped. -func scanFields(ss *ast.StructType, fn scanFunctions) { +func scanFields(ss *ast.StructType, prefix string, fn scanFunctions) { if ss.Fields.List == nil { // No fields. return @@ -127,7 +127,16 @@ func scanFields(ss *ast.StructType, fn scanFunctions) { continue } - switch tag := extractStateTag(field.Tag); tag { + // Is this a anonymous struct? If yes, then continue the + // recursion with the given prefix. We don't pay attention to + // any tags on the top-level struct field. + tag := extractStateTag(field.Tag) + if anon, ok := field.Type.(*ast.StructType); ok && tag == "" { + scanFields(anon, name+".", fn) + continue + } + + switch tag { case "zerovalue": if fn.zerovalue != nil { fn.zerovalue(name) @@ -201,28 +210,12 @@ func main() { // initCalls is dumped at the end. var initCalls []string - // Declare our emission closures. + // Common closures. emitRegister := func(name string) { - initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name)) + initCalls = append(initCalls, fmt.Sprintf("%sRegister((*%s)(nil))", statePrefix, name)) } emitZeroCheck := func(name string) { - fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { m.Failf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, name, name) - } - emitLoadValue := func(name, typName string) { - fmt.Fprintf(outputFile, " m.LoadValue(\"%s\", new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", name, typName, camelCased(name), typName) - } - emitLoad := func(name string) { - fmt.Fprintf(outputFile, " m.Load(\"%s\", &x.%s)\n", name, name) - } - emitLoadWait := func(name string) { - fmt.Fprintf(outputFile, " m.LoadWait(\"%s\", &x.%s)\n", name, name) - } - emitSaveValue := func(name, typName string) { - fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) - fmt.Fprintf(outputFile, " m.SaveValue(\"%s\", %s)\n", name, name) - } - emitSave := func(name string) { - fmt.Fprintf(outputFile, " m.Save(\"%s\", &x.%s)\n", name, name) + fmt.Fprintf(outputFile, " if !%sIsZeroValue(&x.%s) { %sFailf(\"%s is %%#v, expected zero\", &x.%s) }\n", statePrefix, name, statePrefix, name, name) } // Automated warning. @@ -329,87 +322,140 @@ func main() { continue } - // Only generate code for types marked - // "// +stateify savable" in one of the proceeding - // comment lines. + // Only generate code for types marked "// +stateify + // savable" in one of the proceeding comment lines. If + // the line is marked "// +stateify type" then only + // generate type information and register the type. if d.Doc == nil { continue } - savable := false + var ( + generateTypeInfo = false + generateSaverLoader = false + ) for _, l := range d.Doc.List { if l.Text == "// +stateify savable" { - savable = true + generateTypeInfo = true + generateSaverLoader = true break } + if l.Text == "// +stateify type" { + generateTypeInfo = true + } } - if !savable { + if !generateTypeInfo && !generateSaverLoader { continue } for _, gs := range d.Specs { ts := gs.(*ast.TypeSpec) - switch ts.Type.(type) { - case *ast.InterfaceType, *ast.ChanType, *ast.FuncType, *ast.ParenExpr, *ast.StarExpr: - // Don't register. - break + switch x := ts.Type.(type) { case *ast.StructType: maybeEmitImports() - ss := ts.Type.(*ast.StructType) + // Record the slot for each field. + fieldCount := 0 + fields := make(map[string]int) + emitField := func(name string) { + fmt.Fprintf(outputFile, " \"%s\",\n", name) + fields[name] = fieldCount + fieldCount++ + } + emitFieldValue := func(name string, _ string) { + emitField(name) + } + emitLoadValue := func(name, typName string) { + fmt.Fprintf(outputFile, " m.LoadValue(%d, new(%s), func(y interface{}) { x.load%s(y.(%s)) })\n", fields[name], typName, camelCased(name), typName) + } + emitLoad := func(name string) { + fmt.Fprintf(outputFile, " m.Load(%d, &x.%s)\n", fields[name], name) + } + emitLoadWait := func(name string) { + fmt.Fprintf(outputFile, " m.LoadWait(%d, &x.%s)\n", fields[name], name) + } + emitSaveValue := func(name, typName string) { + fmt.Fprintf(outputFile, " var %s %s = x.save%s()\n", name, typName, camelCased(name)) + fmt.Fprintf(outputFile, " m.SaveValue(%d, %s)\n", fields[name], name) + } + emitSave := func(name string) { + fmt.Fprintf(outputFile, " m.Save(%d, &x.%s)\n", fields[name], name) + } + + // Generate the type name method. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) + fmt.Fprintf(outputFile, "}\n\n") + + // Generate the fields method. + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return []string{\n") + scanFields(x, "", scanFunctions{ + normal: emitField, + wait: emitField, + value: emitFieldValue, + }) + fmt.Fprintf(outputFile, " }\n") + fmt.Fprintf(outputFile, "}\n\n") - // Define beforeSave if a definition was not found. This - // prevents the code from compiling if a custom beforeSave - // was defined in a file not provided to this binary and - // prevents inherited methods from being called multiple times - // by overriding them. - if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok { - fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n", ts.Name.Name) + // Define beforeSave if a definition was not found. This prevents + // the code from compiling if a custom beforeSave was defined in a + // file not provided to this binary and prevents inherited methods + // from being called multiple times by overriding them. + if _, ok := simpleMethods[method{ts.Name.Name, "beforeSave"}]; !ok && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) beforeSave() {}\n\n", ts.Name.Name) } // Generate the save method. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " x.beforeSave()\n") - scanFields(ss, scanFunctions{zerovalue: emitZeroCheck}) - scanFields(ss, scanFunctions{value: emitSaveValue}) - scanFields(ss, scanFunctions{normal: emitSave, wait: emitSave}) - fmt.Fprintf(outputFile, "}\n\n") + // + // N.B. For historical reasons, we perform the value saves first, + // and perform the value loads last. There should be no dependency + // on this specific behavior, but the ability to specify slots + // allows a manual implementation to be order-dependent. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateSave(m %sSink) {\n", ts.Name.Name, statePrefix) + fmt.Fprintf(outputFile, " x.beforeSave()\n") + scanFields(x, "", scanFunctions{zerovalue: emitZeroCheck}) + scanFields(x, "", scanFunctions{value: emitSaveValue}) + scanFields(x, "", scanFunctions{normal: emitSave, wait: emitSave}) + fmt.Fprintf(outputFile, "}\n\n") + } - // Define afterLoad if a definition was not found. We do this - // for the same reason that we do it for beforeSave. + // Define afterLoad if a definition was not found. We do this for + // the same reason that we do it for beforeSave. _, hasAfterLoad := simpleMethods[method{ts.Name.Name, "afterLoad"}] - if !hasAfterLoad { - fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n", ts.Name.Name) + if !hasAfterLoad && generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) afterLoad() {}\n\n", ts.Name.Name) } // Generate the load method. // - // Note that the manual loads always follow the - // automated loads. - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - scanFields(ss, scanFunctions{normal: emitLoad, wait: emitLoadWait}) - scanFields(ss, scanFunctions{value: emitLoadValue}) - if hasAfterLoad { - // The call to afterLoad is made conditionally, because when - // AfterLoad is called, the object encodes a dependency on - // referred objects (i.e. fields). This means that afterLoad - // will not be called until the other afterLoads are called. - fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + // N.B. See the comment above for the save method. + if generateSaverLoader { + fmt.Fprintf(outputFile, "func (x *%s) StateLoad(m %sSource) {\n", ts.Name.Name, statePrefix) + scanFields(x, "", scanFunctions{normal: emitLoad, wait: emitLoadWait}) + scanFields(x, "", scanFunctions{value: emitLoadValue}) + if hasAfterLoad { + // The call to afterLoad is made conditionally, because when + // AfterLoad is called, the object encodes a dependency on + // referred objects (i.e. fields). This means that afterLoad + // will not be called until the other afterLoads are called. + fmt.Fprintf(outputFile, " m.AfterLoad(x.afterLoad)\n") + } + fmt.Fprintf(outputFile, "}\n\n") } - fmt.Fprintf(outputFile, "}\n\n") // Add to our registration. emitRegister(ts.Name.Name) + case *ast.Ident, *ast.SelectorExpr, *ast.ArrayType: maybeEmitImports() - _, val := resolveTypeName(ts.Name.Name, ts.Type) - - // Dispatch directly. - fmt.Fprintf(outputFile, "func (x *%s) save(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.SaveValue(\"\", (%s)(*x))\n", val) + // Generate the info methods. + fmt.Fprintf(outputFile, "func (x *%s) StateTypeName() string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return \"%s.%s\"\n", *fullPkg, ts.Name.Name) fmt.Fprintf(outputFile, "}\n\n") - fmt.Fprintf(outputFile, "func (x *%s) load(m %sMap) {\n", ts.Name.Name, statePrefix) - fmt.Fprintf(outputFile, " m.LoadValue(\"\", new(%s), func(y interface{}) { *x = (%s)(y.(%s)) })\n", val, ts.Name.Name, val) + fmt.Fprintf(outputFile, "func (x *%s) StateFields() []string {\n", ts.Name.Name) + fmt.Fprintf(outputFile, " return nil\n") fmt.Fprintf(outputFile, "}\n\n") // See above. diff --git a/tools/installers/BUILD b/tools/installers/BUILD index caa7b1983..13d3cc5e0 100644 --- a/tools/installers/BUILD +++ b/tools/installers/BUILD @@ -5,15 +5,12 @@ package( licenses = ["notice"], ) -filegroup( - name = "runsc", - srcs = ["//runsc"], -) - sh_binary( name = "head", srcs = ["head.sh"], - data = [":runsc"], + data = [ + "//runsc", + ], ) sh_binary( @@ -30,6 +27,15 @@ sh_binary( ) sh_binary( + name = "containerd", + srcs = ["containerd.sh"], +) + +sh_binary( name = "shim", srcs = ["shim.sh"], + data = [ + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", + ], ) diff --git a/tools/installers/containerd.sh b/tools/installers/containerd.sh new file mode 100755 index 000000000..6b7bb261c --- /dev/null +++ b/tools/installers/containerd.sh @@ -0,0 +1,114 @@ +#!/bin/bash + +# Copyright 2019 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -xeo pipefail + +declare -r CONTAINERD_VERSION=${CONTAINERD_VERSION:-1.3.0} +declare -r CONTAINERD_MAJOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $1; }')" +declare -r CONTAINERD_MINOR="$(echo ${CONTAINERD_VERSION} | awk -F '.' '{ print $2; }')" + +# Default to an older version for crictl for containerd <= 1.2. +if [[ "${CONTAINERD_MAJOR}" -eq 1 ]] && [[ "${CONTAINERD_MINOR}" -le 2 ]]; then + declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.13.0} +else + declare -r CRITOOLS_VERSION=${CRITOOLS_VERSION:-1.18.0} +fi + +# Helper for Go packages below. +install_helper() { + PACKAGE="${1}" + TAG="${2}" + + # Clone the repository. + mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ + git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" + + # Checkout and build the repository. + (cd "${GOPATH}"/src/"${PACKAGE}" && \ + git checkout "${TAG}" && \ + make && \ + make install) +} + +# Install dependencies for the crictl tests. +while true; do + if (apt-get update && apt-get install -y \ + btrfs-tools \ + libseccomp-dev); then + break + fi + result=$? + if [[ $result -ne 100 ]]; then + exit $result + fi +done + +# Install containerd & cri-tools. +declare -rx GOPATH=$(mktemp -d --tmpdir gopathXXXXX) +install_helper github.com/containerd/containerd "v${CONTAINERD_VERSION}" "${GOPATH}" +install_helper github.com/kubernetes-sigs/cri-tools "v${CRITOOLS_VERSION}" "${GOPATH}" + +# Configure containerd-shim. +# +# Note that for versions <= 1.1 the legacy shim must be installed in /usr/bin, +# which should align with the installer script in head.sh (or master.sh). +if [[ "${CONTAINERD_MAJOR}" -le 1 ]] && [[ "${CONTAINERD_MINOR}" -lt 2 ]]; then + declare -r shim_config_path=/etc/containerd/gvisor-containerd-shim.toml + mkdir -p $(dirname ${shim_config_path}) + cat > ${shim_config_path} <<-EOF + runc_shim = "/usr/bin/containerd-shim" + +[runsc_config] + debug = "true" + debug-log = "/tmp/runsc-logs/" + strace = "true" + file-access = "shared" +EOF +fi + +# Configure CNI. +(cd "${GOPATH}" && src/github.com/containerd/containerd/script/setup/install-cni) +cat <<EOF | sudo tee /etc/cni/net.d/10-bridge.conf +{ + "cniVersion": "0.3.1", + "name": "bridge", + "type": "bridge", + "bridge": "cnio0", + "isGateway": true, + "ipMasq": true, + "ipam": { + "type": "host-local", + "ranges": [ + [{"subnet": "10.200.0.0/24"}] + ], + "routes": [{"dst": "0.0.0.0/0"}] + } +} +EOF +cat <<EOF | sudo tee /etc/cni/net.d/99-loopback.conf +{ + "cniVersion": "0.3.1", + "type": "loopback" +} +EOF + +# Configure crictl. +cat <<EOF | sudo tee /etc/crictl.yaml +runtime-endpoint: unix:///run/containerd/containerd.sock +EOF + +# Cleanup. +rm -rf "${GOPATH}" diff --git a/tools/installers/head.sh b/tools/installers/head.sh index 7fc566ebd..a613fcb5b 100755 --- a/tools/installers/head.sh +++ b/tools/installers/head.sh @@ -15,7 +15,13 @@ # limitations under the License. # Install our runtime. -$(find . -executable -type f -name runsc) install +runfiles=. +if [[ -d "$0.runfiles" ]]; then + runfiles="$0.runfiles" +fi +$(find -L "${runfiles}" -executable -type f -name runsc) install # Restart docker. -service docker restart || true +if service docker status 2>/dev/null; then + service docker restart +fi diff --git a/tools/installers/shim.sh b/tools/installers/shim.sh index f7dd790a1..8153ce283 100755 --- a/tools/installers/shim.sh +++ b/tools/installers/shim.sh @@ -14,11 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Reinstall the latest containerd shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin/gvisor-containerd-shim +# Install all the shims. +# +# Note that containerd looks at the current executable directory +# in order to find the shim binary. So we need to check in order +# of preference. The local containerd installer will install to +# /usr/local, so we use that first. +if [[ -x /usr/local/bin/containerd ]]; then + containerd_install_dir=/usr/local/bin +else + containerd_install_dir=/usr/bin +fi +runfiles=. +if [[ -d "$0.runfiles" ]]; then + runfiles="$0.runfiles" +fi +find -L "${runfiles}" -executable -type f -name containerd-shim-runsc-v1 -exec cp -L {} "${containerd_install_dir}" \; +find -L "${runfiles}" -executable -type f -name gvisor-containerd-shim -exec cp -L {} "${containerd_install_dir}" \; diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD index da4133472..8b1c717df 100644 --- a/tools/issue_reviver/github/BUILD +++ b/tools/issue_reviver/github/BUILD @@ -5,12 +5,13 @@ package(licenses = ["notice"]) go_library( name = "github", srcs = ["github.go"], + nogo = False, visibility = [ "//tools/issue_reviver:__subpackages__", ], deps = [ "//tools/issue_reviver/reviver", - "@com_github_google_go-github//github:go_default_library", + "@com_github_google_go_github_v28//github:go_default_library", "@org_golang_x_oauth2//:go_default_library", ], ) diff --git a/tools/make_release.sh b/tools/make_release.sh index b1cdd47b0..9137dd9bb 100755 --- a/tools/make_release.sh +++ b/tools/make_release.sh @@ -43,8 +43,7 @@ install_raw() { # Copy the raw file & generate a sha512sum. name=$(basename "${binary}") cp -f "${binary}" "${root}/$1" - sha512sum "${root}/$1/${name}" | \ - awk "{print $$1 \" ${name}\"}" > "${root}/$1/${name}.sha512" + (cd "${root}/$1" && sha512sum "${name}" > "${name}.sha512") done } diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 1c0d08661..433d13738 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -31,6 +31,10 @@ var ( ) // findStdPkg needs to find the bundled standard library packages. -func findStdPkg(path, GOOS, GOARCH string) (io.ReadCloser, error) { - return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) +func (i *importer) findStdPkg(path string) (io.ReadCloser, error) { + if path == "C" { + // Cgo builds cannot be analyzed. Skip. + return nil, ErrSkip + } + return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", i.GOOS, i.GOARCH, path)) } diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 6560b57c8..d399079c5 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -28,8 +28,10 @@ def _nogo_aspect_impl(target, ctx): else: return [NogoInfo()] - # Construct the Go environment from the go_context.env dictionary. - env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_context(ctx).env.items()]) + go_ctx = go_context(ctx) + + # Construct the Go environment from the go_ctx.env dictionary. + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) # Start with all target files and srcs as input. inputs = target.files.to_list() + srcs @@ -45,7 +47,7 @@ def _nogo_aspect_impl(target, ctx): "#!/bin/bash", "%s %s tool objdump %s > %s\n" % ( env_prefix, - go_context(ctx).go.path, + go_ctx.go.path, [f.path for f in binaries if f.path.endswith(".a")][0], disasm_file.path, ), @@ -53,7 +55,7 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = binaries, outputs = [disasm_file], - tools = go_context(ctx).runfiles, + tools = go_ctx.runfiles, mnemonic = "GoObjdump", progress_message = "Objdump %s" % target.label, executable = dumper, @@ -70,9 +72,11 @@ def _nogo_aspect_impl(target, ctx): ImportPath = importpath, GoFiles = [src.path for src in srcs if src.path.endswith(".go")], NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], - GOOS = go_context(ctx).goos, - GOARCH = go_context(ctx).goarch, - Tags = go_context(ctx).tags, + # Google's internal build system needs a bit more help to find std. + StdZip = go_ctx.std_zip.short_path if hasattr(go_ctx, "std_zip") else "", + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, FactMap = {}, # Constructed below. ImportMap = {}, # Constructed below. FactOutput = facts.path, @@ -110,7 +114,7 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = inputs, outputs = [facts], - tools = go_context(ctx).runfiles, + tools = go_ctx.runfiles, executable = ctx.files._nogo[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index 203cdf688..ea1e97076 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -20,6 +20,7 @@ package nogo import ( "encoding/json" + "errors" "flag" "fmt" "go/ast" @@ -54,6 +55,7 @@ type pkgConfig struct { FactMap map[string]string FactOutput string Objdump string + StdZip string } // loadFacts finds and loads facts per FactMap. @@ -89,8 +91,9 @@ func (c *pkgConfig) shouldInclude(path string) (bool, error) { // pass when a given package is not available. type importer struct { pkgConfig - fset *token.FileSet - cache map[string]*types.Package + fset *token.FileSet + cache map[string]*types.Package + lastErr error } // Import implements types.Importer.Import. @@ -109,12 +112,13 @@ func (i *importer) Import(path string) (*types.Package, error) { if !ok { // Not found in the import path. Attempt to find the package // via the standard library. - rc, err = findStdPkg(path, i.GOOS, i.GOARCH) + rc, err = i.findStdPkg(path) } else { // Open the file. rc, err = os.Open(realPath) } if err != nil { + i.lastErr = err return nil, err } defer rc.Close() @@ -128,6 +132,9 @@ func (i *importer) Import(path string) (*types.Package, error) { return gcexportdata.Read(r, i.fset, i.cache, path) } +// ErrSkip indicates the package should be skipped. +var ErrSkip = errors.New("skipped") + // checkPackage runs all analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. @@ -172,14 +179,14 @@ func checkPackage(config pkgConfig) ([]string, error) { Selections: make(map[*ast.SelectorExpr]*types.Selection), } types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) - if err != nil { - return nil, fmt.Errorf("error checking types: %v", err) + if err != nil && imp.lastErr != ErrSkip { + return nil, fmt.Errorf("error checking types: %w", err) } // Load all package facts. facts, err := facts.Decode(types, config.loadFacts) if err != nil { - return nil, fmt.Errorf("error decoding facts: %v", err) + return nil, fmt.Errorf("error decoding facts: %w", err) } // Set the binary global for use. @@ -247,6 +254,9 @@ func checkPackage(config pkgConfig) ([]string, error) { // Visit all analysis recursively. for a, _ := range analyzerConfig { + if imp.lastErr == ErrSkip { + continue // No local analysis. + } if err := visit(a); err != nil { return nil, err // Already has context. } diff --git a/tools/vm/README.md b/tools/vm/README.md index 898c95fca..1e9859e66 100644 --- a/tools/vm/README.md +++ b/tools/vm/README.md @@ -25,6 +25,12 @@ vm_image( These images can be built manually by executing the target. The output on `stdout` will be the image id (in the current project). +For example: + +``` +$ bazel build :ubuntu +``` + Images are always named per the hash of all the hermetic input scripts. This allows images to be memoized quickly and easily. diff --git a/tools/vm/defs.bzl b/tools/vm/defs.bzl index 0f67cfa92..9af5ad3b4 100644 --- a/tools/vm/defs.bzl +++ b/tools/vm/defs.bzl @@ -60,11 +60,12 @@ def _vm_image_impl(ctx): # Run the builder to generate our output. echo = ctx.actions.declare_file(ctx.label.name) resolved_inputs, argv, runfiles_manifests = ctx.resolve_command( - command = "echo -ne \"#!/bin/bash\\nset -e\\nimage=$(%s)\\necho ${image}\\n\" > %s && chmod 0755 %s" % ( - ctx.files.builder[0].path, - echo.path, - echo.path, - ), + command = "\n".join([ + "set -e", + "image=$(%s)" % ctx.files.builder[0].path, + "echo -ne \"#!/bin/bash\\necho ${image}\\n\" > %s" % echo.path, + "chmod 0755 %s" % echo.path, + ]), tools = [ctx.attr.builder], ) ctx.actions.run_shell( diff --git a/tools/vm/ubuntu1604/30_containerd.sh b/tools/vm/ubuntu1604/30_containerd.sh deleted file mode 100755 index fb3699c12..000000000 --- a/tools/vm/ubuntu1604/30_containerd.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash - -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeo pipefail - -# Helper for Go packages below. -install_helper() { - PACKAGE="${1}" - TAG="${2}" - GOPATH="${3}" - - # Clone the repository. - mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \ - git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}" - - # Checkout and build the repository. - (cd "${GOPATH}"/src/"${PACKAGE}" && \ - git checkout "${TAG}" && \ - GOPATH="${GOPATH}" make && \ - GOPATH="${GOPATH}" make install) -} - -# Install dependencies for the crictl tests. -while true; do - if (apt-get update && apt-get install -y \ - btrfs-tools \ - libseccomp-dev); then - break - fi - result=$? - if [[ $result -ne 100 ]]; then - exit $result - fi -done - -# Install containerd & cri-tools. -GOPATH=$(mktemp -d --tmpdir gopathXXXXX) -install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}" -install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}" - -# Install gvisor-containerd-shim. -declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim" -declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX) -declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX) -wget --no-verbose "${base}"/latest -O ${latest} -wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path} -chmod +x ${shim_path} -mv ${shim_path} /usr/local/bin - -# Configure containerd-shim. -declare -r shim_config_path=/etc/containerd -declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml) -mkdir -p ${shim_config_path} -cat > ${shim_config_tmp_path} <<-EOF - runc_shim = "/usr/local/bin/containerd-shim" - -[runsc_config] - debug = "true" - debug-log = "/tmp/runsc-logs/" - strace = "true" - file-access = "shared" -EOF -mv ${shim_config_tmp_path} ${shim_config_path} - -# Configure CNI. -(cd "${GOPATH}" && GOPATH="${GOPATH}" \ - src/github.com/containerd/containerd/script/setup/install-cni) - -# Cleanup the above. -rm -rf "${GOPATH}" -rm -rf "${latest}" -rm -rf "${shim_path}" -rm -rf "${shim_config_tmp_path}" diff --git a/tools/vm/ubuntu1604/25_docker.sh b/tools/vm/ubuntu1604/30_docker.sh index 11eea2d72..d393133e4 100755 --- a/tools/vm/ubuntu1604/25_docker.sh +++ b/tools/vm/ubuntu1604/30_docker.sh @@ -52,3 +52,13 @@ while true; do exit $result fi done + +# Enable experimental features, for cross-building aarch64 images. +# Enable Docker IPv6. +cat > /etc/docker/daemon.json <<EOF +{ + "experimental": true, + "fixed-cidr-v6": "2001:db8:1::/64", + "ipv6": true +} +EOF diff --git a/tools/vm/ubuntu1604/40_kokoro.sh b/tools/vm/ubuntu1604/40_kokoro.sh index 2974f156c..d3b96c9ad 100755 --- a/tools/vm/ubuntu1604/40_kokoro.sh +++ b/tools/vm/ubuntu1604/40_kokoro.sh @@ -41,7 +41,7 @@ while true; do done # junitparser is used to merge junit xml files. -pip install junitparser +pip install --no-cache-dir junitparser # We need a kbuilder user, which may already exist. useradd -c "kbuilder user" -m -s /bin/bash kbuilder || true diff --git a/website/BUILD b/website/BUILD index c97b2560b..10e0299ae 100644 --- a/website/BUILD +++ b/website/BUILD @@ -55,9 +55,7 @@ genrule( "docker run -i --user $$(id -u):$$(id -g) " + "-v $$(readlink -m $$T/output/_site):/output " + "gvisor.dev/images/jekyll " + - "/usr/gem/bin/htmlproofer " + - "--disable-external " + - "--check-html " + + "ruby /checks.rb " + "/output && " + "cp $(location //website/cmd/server) $$T/output/server && " + "tar -zcf $@ -C $$T/output . && " + @@ -138,6 +136,7 @@ docs( "//g3doc:community", "//g3doc:index", "//g3doc:roadmap", + "//g3doc:style", "//g3doc/architecture_guide:performance", "//g3doc/architecture_guide:platforms", "//g3doc/architecture_guide:resources", diff --git a/website/_includes/footer-links.html b/website/_includes/footer-links.html index 10c28ead4..2036dbaa9 100644 --- a/website/_includes/footer-links.html +++ b/website/_includes/footer-links.html @@ -15,7 +15,7 @@ <ul class="list-unstyled"> <li><a href="https://github.com/google/gvisor/issues">Issues</a></li> <li><a href="/docs">Documentation</a></li> - <li><a href="/docs/user_guide/FAQ">FAQ</a></li> + <li><a href="/docs/user_guide/faq">FAQ</a></li> </ul> </div> <div class="col-sm-3 col-md-2"> diff --git a/website/_includes/footer.html b/website/_includes/footer.html index 9cc8176f7..c1a373329 100644 --- a/website/_includes/footer.html +++ b/website/_includes/footer.html @@ -8,7 +8,7 @@ <script src="https://cdnjs.cloudflare.com/ajax/libs/d3/4.13.0/d3.min.js" integrity="sha256-hYXbQJK4qdJiAeDVjjQ9G0D6A0xLnDQ4eJI9dkm7Fpk=" crossorigin="anonymous"></script> {% if site.analytics %} -<script type="application/javascript"> +<script> var doNotTrack = false; if (!doNotTrack) { window.ga=window.ga||function(){(ga.q=ga.q||[]).push(arguments)};ga.l=+new Date; diff --git a/website/_includes/graph.html b/website/_includes/graph.html index f3a999341..ba4cf9840 100644 --- a/website/_includes/graph.html +++ b/website/_includes/graph.html @@ -1,7 +1,7 @@ {::nomarkdown} {% assign fn = include.id | remove: " " | remove: "-" | downcase %} <figure><a href="{{ include.url }}"><svg id="{{ include.id }}" width=500 height=200 onload="render_{{ fn }}()"><title>{{ include.title }}</title></svg></a></figure> -<script type="text/javascript"> +<script> function render_{{ fn }}() { d3.csv("{{ include.url }}", function(d, i, columns) { return d; // Transformed below. diff --git a/website/_includes/header-links.html b/website/_includes/header-links.html index 467bb1e72..4232fdaa5 100644 --- a/website/_includes/header-links.html +++ b/website/_includes/header-links.html @@ -2,7 +2,7 @@ <div class="container"> <div class="navbar-brand"> <a href="/"> - <img src="/assets/logos/logo_solo_on_dark.svg" height="25px" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo"/> + <img src="/assets/logos/logo_solo_on_dark.svg" height="25" class="d-inline-block align-top" style="margin-right: 10px;" alt="logo" /> gVisor </a> </div> diff --git a/website/_layouts/docs.html b/website/_layouts/docs.html index 549305089..6bc5d87db 100644 --- a/website/_layouts/docs.html +++ b/website/_layouts/docs.html @@ -47,8 +47,8 @@ categories: <h1>{{ page.title }}</h1> {% if page.editpath %} <p> - <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank"><i class="fa fa-edit fa-fw"></i> Edit this page</a> - <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank"><i class="fab fa-github fa-fw"></i> Create issue</a> + <a href="https://github.com/google/gvisor/edit/master/{{page.editpath}}" target="_blank" rel="noopener"><i class="fa fa-edit fa-fw"></i> Edit this page</a> + <a href="https://github.com/google/gvisor/issues/new?title={{page.title | url_encode}}" target="_blank" rel="noopener"><i class="fab fa-github fa-fw"></i> Create issue</a> </p> {% endif %} <div class="docs-content"> diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md index fbdd511dd..76bbabc13 100644 --- a/website/blog/2019-11-18-security-basics.md +++ b/website/blog/2019-11-18-security-basics.md @@ -44,10 +44,10 @@ into it in the next section! # Design Principles -gVisor was designed with some -[common secure design principles](https://www.owasp.org/index.php/Security_by_Design_Principles) -in mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface -Reduction and Secure-by-Default[^1]. +gVisor was designed with some common +[secure design](https://en.wikipedia.org/wiki/Secure_by_design) principles in +mind: Defense-in-Depth, Principle of Least-Privilege, Attack Surface Reduction +and Secure-by-Default[^1]. In general, Design Principles outline good engineering practices, but in the case of security, they also can be thought of as a set of tactics. In a @@ -282,16 +282,23 @@ stable. ## Notes -[^1]: [https://www.owasp.org/index.php/Security_by_Design_Principles](https://www.owasp.org/index.php/Security_by_Design_Principles) +[^1]: [https://en.wikipedia.org/wiki/Secure_by_design](https://en.wikipedia.org/wiki/Secure_by_design) [^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/) [^3]: [https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/linux/linux64_amd64.go](https://github.com/google/gvisor/blob/master/pkg/sentry/syscalls/syscalls.go) -[^4]: Internally that is, it doesn't call to the Host OS to implement them, in - fact that is explicitly disallowed, more on that in the future. + +<!-- mdformat off(mdformat formats this into multiple lines) --> +[^4]: Internally that is, it doesn't call to the Host OS to implement them, in fact that is explicitly disallowed, more on that in the future. +<!-- mdformat on --> + [^5]: [https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345](https://elixir.bootlin.com/linux/latest/source/arch/x86/entry/syscalls/syscall_64.tbl#L345) [^6]: [https://github.com/google/gvisor/tree/master/runsc/boot/filter](https://github.com/google/gvisor/tree/master/runsc/boot/filter) [^7]: [https://en.wikipedia.org/wiki/Dirty_COW](https://en.wikipedia.org/wiki/Dirty_COW) [^8]: [https://github.com/google/gvisor/blob/master/runsc/boot/config.go](https://github.com/google/gvisor/blob/master/runsc/boot/config.go) -[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_\(protocol\)) + +<!-- mdformat off(mdformat breaks this url by escaping the parenthesis) --> +[^9]: [https://en.wikipedia.org/wiki/9P_(protocol)](https://en.wikipedia.org/wiki/9P_(protocol)) +<!-- mdformat on --> + [^10]: [https://gvisor.dev/docs/user_guide/networking/#network-passthrough](https://gvisor.dev/docs/user_guide/networking/#network-passthrough) [^11]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ptrace/subprocess.go#L390) [^12]: [https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182](https://github.com/google/gvisor/blob/c7e901f47a09eaac56bd4813227edff016fa6bff/pkg/sentry/platform/ring0/kernel_amd64.go#L182) diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go index 7c8bc9bfa..c401b6abd 100644 --- a/website/cmd/server/main.go +++ b/website/cmd/server/main.go @@ -35,6 +35,10 @@ var redirects = map[string]string{ // For links. "/faq": "/docs/user_guide/faq/", + // From 2020-05-12 to 2020-06-30, the FAQ URL was uppercase. Redirect that + // back to maintain any links. + "/docs/user_guide/FAQ/": "/docs/user_guide/faq/", + // Redirects to compatibility docs. "/c": "/docs/user_guide/compatibility/", "/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/", diff --git a/website/defs.bzl b/website/defs.bzl index ead6a3067..f52946c15 100644 --- a/website/defs.bzl +++ b/website/defs.bzl @@ -1,5 +1,7 @@ """Wrappers for website documentation.""" +load("//tools:defs.bzl", "short_path") + # DocInfo is a provider which simple adds sufficient metadata to the source # files (and additional data files) so that a jeyll header can be constructed # dynamically. This is done the via BUILD system so that the plain @@ -29,7 +31,7 @@ def _doc_impl(ctx): category = ctx.attr.category, subcategory = ctx.attr.subcategory, weight = ctx.attr.weight, - editpath = ctx.files.src[0].short_path, + editpath = short_path(ctx.files.src[0].short_path), authors = ctx.attr.authors, ), ] |