416 files changed, 18142 insertions, 7073 deletions
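The largest functional addition in this change is the benchmark tooling in the Makefile (the `## Benchmarks.` section in the hunks below). As a rough usage sketch of those targets — the target and variable names come from that section, while the project name and the `TARGETS` pattern are illustrative placeholders, not values from this change:

```bash
# Sketch only: exercising the new benchmark targets added in the Makefile below.
# Assumes a working Docker/Bazel setup and, for uploads, GCP credentials.

# One-time: create the BigQuery table used for uploads (a no-op if it exists).
make init-benchmark-table BENCHMARKS_PROJECT=my-gcp-project BENCHMARKS_DATASET=kokoro

# Run benchmarks for runc plus every platform in BENCHMARKS_PLATFORMS; results
# are uploaded only when BENCHMARKS_UPLOAD=true. TARGETS selects the benchmark
# tests to run (anything under //test/benchmarks/...).
make benchmark-platforms BENCHMARKS_PLATFORMS="ptrace kvm" \
    TARGETS="//test/benchmarks/..." \
    BENCHMARKS_SUITE=start BENCHMARKS_UPLOAD=true BENCHMARKS_PROJECT=my-gcp-project
```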
@@ -1,10 +1,17 @@
 load("//tools:defs.bzl", "build_test", "gazelle", "go_path")
+load("//tools/nogo:defs.bzl", "nogo_config")
 load("//website:defs.bzl", "doc")

 package(licenses = ["notice"])

 exports_files(["LICENSE"])

+nogo_config(
+    name = "nogo_config",
+    srcs = ["nogo.yaml"],
+    visibility = ["//:sandbox"],
+)
+
 doc(
     name = "contributing",
     src = "CONTRIBUTING.md",
@@ -86,6 +93,7 @@ go_path(
         "//runsc/cli",
         "//shim/v1/cli",
         "//shim/v2/cli",
+        "//webhook/pkg/cli",

         # Packages that are not dependencies of the above.
         "//pkg/sentry/kernel/memevent",
@@ -156,12 +156,24 @@ syscall-tests: ## Run all system call tests.
 	@$(call submake,test TARGETS="test/syscalls/...")

 %-runtime-tests: load-runtimes_%
+ifeq ($(PARTITION),)
+	@$(eval PARTITION := 1)
+endif
+ifeq ($(TOTAL_PARTITIONS),)
+	@$(eval TOTAL_PARTITIONS := 1)
+endif
 	@$(call submake,install-test-runtime)
-	@$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+	@$(call submake,test-runtime OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")

 %-runtime-tests_vfs2: load-runtimes_%
+ifeq ($(PARTITION),)
+	@$(eval PARTITION := 1)
+endif
+ifeq ($(TOTAL_PARTITIONS),)
+	@$(eval TOTAL_PARTITIONS := 1)
+endif
 	@$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2")
-	@$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*")
+	@$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*")

 do-tests: runsc
 	@$(call submake,run TARGETS="//runsc" ARGS="--rootless do true")
@@ -210,6 +222,15 @@ iptables-tests: load-iptables
 	@$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
 .PHONY: iptables-tests

+# Run the iptables tests with runsc only. Useful for developing to skip runc
+# testing.
+iptables-runsc-tests: load-iptables
+	@sudo modprobe iptable_filter
+	@sudo modprobe ip6table_filter
+	@$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw")
+	@$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test")
+.PHONY: iptables-runsc-tests
+
 packetdrill-tests: load-packetdrill
 	@$(call submake,install-test-runtime RUNTIME="packetdrill")
 	@$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')")
@@ -240,6 +261,61 @@ containerd-tests: containerd-test-1.3.4
 containerd-tests: containerd-test-1.4.0-beta.0

 ##
+## Benchmarks.
+##
+## Targets to run benchmarks. See //test/benchmarks for details.
+##
+## common arguments:
+## RUNTIME_ARGS - arguments to runsc placed in /etc/docker/daemon.json
+## e.g. "--platform=ptrace"
+## BENCHMARKS_PROJECT - BigQuery project to which to send data.
+## BENCHMARKS_DATASET - BigQuery dataset to which to send data.
+## BENCHMARKS_TABLE - BigQuery table to which to send data.
+## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go.
+## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run.
+## BENCHMARKS_OFFICIAL - marks the data as official.
+## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm).
+##
+RUNTIME_ARGS := --net-raw --platform=ptrace
+BENCHMARKS_PROJECT := gvisor-benchmarks
+BENCHMARKS_DATASET := kokoro
+BENCHMARKS_TABLE := benchmarks
+BENCHMARKS_SUITE := start
+BENCHMARKS_UPLOAD := false
+BENCHMARKS_OFFICIAL := false
+BENCHMARKS_PLATFORMS := ptrace
+
+init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema
+## (see //tools/bigquery/bigquery.go). If the table already exists, this is a noop.
+	$(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \
+	  --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)")
+.PHONY: init-benchmark-table
+
+benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
+	$(call submake, run-benchmark RUNTIME="runc")
+	$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS),\
+	  $(call submake,benchmark-platform RUNTIME="$(PLATFORM)" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw --vfs2") && \
+	  $(call submake,benchmark-platform RUNTIME="$(PLATFORM)_vfs1" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw"))
+.PHONY: benchmark-platforms
+
+benchmark-platform: ## Installs a runtime with the given platform args.
+	@$(call submake,install-test-runtime ARGS="$(RUNTIME_ARGS)")
+	@$(call submake, run-benchmark)
+.PHONY: benchmark-platform
+
+run-benchmark: ## Runs a single benchmark and optionally sends data to BigQuery.
+	$(eval T := $(shell mktemp /tmp/logs.$(RUNTIME).XXXXXX))
+	$(call submake,sudo TARGETS="$(TARGETS)" ARGS="--runtime=$(RUNTIME) $(ARGS)" | tee $(T))
+	@if [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \
+	  @$(call submake,run TARGETS=tools/parsers:parser ARGS="parse --file=$(T) \
+	    --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \
+	    --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \
+	    --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \
+	fi;
+	rm -rf $T
+.PHONY: run-benchmark
+
+##
 ## Website & documentation helpers.
 ##
 ## The website is built from repository documentation and wrappers, using
@@ -260,7 +336,7 @@ website-build: load-jekyll ## Build the site image locally.
 .PHONY: website-build

 website-server: website-build ## Run a local server for development.
-	@docker run -i -p 8080:8080 gvisor.dev/images/website
+	@docker run -i -p 8080:8080 $(WEBSITE_IMAGE)
 .PHONY: website-server

 website-push: website-build ## Push a new image and update the service.
@@ -23,13 +23,13 @@ bazel_skylib_workspace()

 http_archive(
     name = "io_bazel_rules_go",
-    sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00",
     patch_args = ["-p1"],
     patches = [
         # Newer versions of the rules_go rules will automatically strip test
         # binaries of symbols, which we don't want.
         "//tools:rules_go.patch",
     ],
+    sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00",
     urls = [
         "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
         "https://github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz",
@@ -49,7 +49,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe

 go_rules_dependencies()

-go_register_toolchains(go_version = "1.14.2")
+go_register_toolchains(go_version = "1.15.2")

 load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")

@@ -58,7 +58,7 @@ gazelle_dependencies()

 # The com_google_protobuf repository below would trigger downloading a older
 # version of org_golang_x_sys.
If putting this repository statment in a place # after that of the com_google_protobuf, this statement will not work as -# expectd to download a new version of org_golang_x_sys. +# expected to download a new version of org_golang_x_sys. go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", @@ -222,8 +222,8 @@ go_repository( go_repository( name = "com_github_google_uuid", importpath = "github.com/google/uuid", - sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=", - version = "v1.0.0", + sum = "h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=", + version = "v1.1.1", ) go_repository( @@ -328,8 +328,8 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", - version = "v0.0.0-20201002184944-ecd9fd270d5d", + sum = "h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk=", + version = "v0.0.0-20201021000207-d49c4edd7d96", ) go_repository( @@ -349,8 +349,8 @@ go_repository( go_repository( name = "com_github_golang_protobuf", importpath = "github.com/golang/protobuf", - sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=", - version = "v1.4.2", + sum = "h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0=", + version = "v1.4.1", ) go_repository( @@ -412,7 +412,7 @@ go_repository( go_repository( name = "com_github_konsorten_go_windows_terminal_sequences", importpath = "github.com/konsorten/go-windows-terminal-sequences", - sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", + sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=", version = "v1.0.3", ) @@ -461,8 +461,8 @@ go_repository( go_repository( name = "org_uber_go_multierr", importpath = "go.uber.org/multierr", - sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=", - version = "v1.2.0", + sum = "h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=", + version = "v1.6.0", ) go_repository( @@ -482,8 +482,8 @@ go_repository( go_repository( name = "co_honnef_go_tools", importpath = "honnef.co/go/tools", - sum = "h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM=", - version = "v0.0.1-2019.2.3", + sum = "h1:W18jzjh8mfPez+AwGLxmOImucz/IFjpNlrKVnaj2YVc=", + version = "v0.0.1-2020.1.6", ) go_repository( @@ -623,8 +623,8 @@ go_repository( go_repository( name = "com_github_google_go_cmp", importpath = "github.com/google/go-cmp", - sum = "h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=", - version = "v0.5.1", + sum = "h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI=", + version = "v0.5.3-0.20201020212313-ab46b8bd0abd", ) go_repository( @@ -721,8 +721,8 @@ go_repository( go_repository( name = "com_github_spf13_pflag", importpath = "github.com/spf13/pflag", - sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=", - version = "v1.0.1-0.20171106142849-4c012f6dcd95", + sum = "h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=", + version = "v1.0.5", ) go_repository( @@ -763,15 +763,15 @@ go_repository( go_repository( name = "org_golang_google_genproto", importpath = "google.golang.org/genproto", - sum = "h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=", - version = "v0.0.0-20200526211855-cb27e3aa2013", + sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", + version = "v0.0.0-20200117163144-32f20d992d24", ) go_repository( name = "org_golang_google_protobuf", importpath = "google.golang.org/protobuf", - sum = "h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=", - version = "v1.25.1-0.20200808011614-a180de9f97d9", + sum = "h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A=", + version = "v1.25.1-0.20201020201750-d3470999428b", ) go_repository( @@ 
-1032,3 +1032,356 @@ go_repository( sum = "h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=", version = "v1.0.0", ) + +go_repository( + name = "com_github_azure_go_autorest_autorest", + importpath = "github.com/Azure/go-autorest/autorest", + sum = "h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs=", + version = "v0.9.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_adal", + importpath = "github.com/Azure/go-autorest/autorest/adal", + sum = "h1:q2gDruN08/guU9vAjuPWff0+QIrpH6ediguzdAzXAUU=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_date", + importpath = "github.com/Azure/go-autorest/autorest/date", + sum = "h1:YGrhWfrgtFs84+h0o46rJrlmsZtyZRg470CqAXTZaGM=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_mocks", + importpath = "github.com/Azure/go-autorest/autorest/mocks", + sum = "h1:Ww5g4zThfD/6cLb4z6xxgeyDa7QDkizMkJKe0ysZXp0=", + version = "v0.2.0", +) + +go_repository( + name = "com_github_azure_go_autorest_logger", + importpath = "github.com/Azure/go-autorest/logger", + sum = "h1:ruG4BSDXONFRrZZJ2GUXDiUyVpayPmb1GnWeHDdaNKY=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_azure_go_autorest_tracing", + importpath = "github.com/Azure/go-autorest/tracing", + sum = "h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_dgrijalva_jwt_go", + importpath = "github.com/dgrijalva/jwt-go", + sum = "h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=", + version = "v3.2.0+incompatible", +) + +go_repository( + name = "com_github_docker_spdystream", + importpath = "github.com/docker/spdystream", + sum = "h1:cenwrSVm+Z7QLSV/BsnenAOcDXdX4cMv4wP0B/5QbPg=", + version = "v0.0.0-20160310174837-449fdfce4d96", +) + +go_repository( + name = "com_github_elazarl_goproxy", + importpath = "github.com/elazarl/goproxy", + sum = "h1:p1yVGRW3nmb85p1Sh1ZJSDm4A4iKLS5QNbvUHMgGu/M=", + version = "v0.0.0-20170405201442-c4fc26588b6e", +) + +go_repository( + name = "com_github_emicklei_go_restful", + importpath = "github.com/emicklei/go-restful", + sum = "h1:H2pdYOb3KQ1/YsqVWoWNLQO+fusocsw354rqGTZtAgw=", + version = "v0.0.0-20170410110728-ff4f55a20633", +) + +go_repository( + name = "com_github_evanphx_json_patch", + importpath = "github.com/evanphx/json-patch", + sum = "h1:fUDGZCv/7iAN7u0puUVhvKCcsR6vRfwrJatElLBEf0I=", + version = "v4.2.0+incompatible", +) + +go_repository( + name = "com_github_fsnotify_fsnotify", + importpath = "github.com/fsnotify/fsnotify", + sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=", + version = "v1.4.7", +) + +go_repository( + name = "com_github_ghodss_yaml", + importpath = "github.com/ghodss/yaml", + sum = "h1:ZktWZesgun21uEDrwW7iEV1zPCGQldM2atlJZ3TdvVM=", + version = "v0.0.0-20150909031657-73d445a93680", +) + +go_repository( + name = "com_github_go_logr_logr", + importpath = "github.com/go-logr/logr", + sum = "h1:M1Tv3VzNlEHg6uyACnRdtrploV2P7wZqH8BoQMtz0cg=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_go_openapi_jsonpointer", + importpath = "github.com/go-openapi/jsonpointer", + sum = "h1:wSt/4CYxs70xbATrGXhokKF1i0tZjENLOo1ioIO13zk=", + version = "v0.0.0-20160704185906-46af16f9f7b1", +) + +go_repository( + name = "com_github_go_openapi_jsonreference", + importpath = "github.com/go-openapi/jsonreference", + sum = "h1:tF+augKRWlWx0J0B7ZyyKSiTyV6E1zZe+7b3qQlcEf8=", + version = "v0.0.0-20160704190145-13c6e3589ad9", +) + +go_repository( + name = "com_github_go_openapi_spec", + importpath 
= "github.com/go-openapi/spec", + sum = "h1:C1JKChikHGpXwT5UQDFaryIpDtyyGL/CR6C2kB7F1oc=", + version = "v0.0.0-20160808142527-6aced65f8501", +) + +go_repository( + name = "com_github_go_openapi_swag", + importpath = "github.com/go-openapi/swag", + sum = "h1:zP3nY8Tk2E6RTkqGYrarZXuzh+ffyLDljLxCy1iJw80=", + version = "v0.0.0-20160704191624-1d0bd113de87", +) + +go_repository( + name = "com_github_google_gofuzz", + importpath = "github.com/google/gofuzz", + sum = "h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_googleapis_gnostic", + build_file_proto_mode = "disable_global", + importpath = "github.com/googleapis/gnostic", + sum = "h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k=", + version = "v0.0.0-20170729233727-0c5108395e2d", +) + +go_repository( + name = "com_github_gophercloud_gophercloud", + importpath = "github.com/gophercloud/gophercloud", + sum = "h1:P/nh25+rzXouhytV2pUHBb65fnds26Ghl8/391+sT5o=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_gregjones_httpcache", + importpath = "github.com/gregjones/httpcache", + sum = "h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=", + version = "v0.0.0-20180305231024-9cad4c3443a7", +) + +go_repository( + name = "com_github_hpcloud_tail", + importpath = "github.com/hpcloud/tail", + sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_imdario_mergo", + importpath = "github.com/imdario/mergo", + sum = "h1:JboBksRwiiAJWvIYJVo46AfV+IAIKZpfrSzVKj42R4Q=", + version = "v0.3.5", +) + +go_repository( + name = "com_github_json_iterator_go", + importpath = "github.com/json-iterator/go", + sum = "h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=", + version = "v1.1.7", +) + +go_repository( + name = "com_github_mailru_easyjson", + importpath = "github.com/mailru/easyjson", + sum = "h1:TpvdAwDAt1K4ANVOfcihouRdvP+MgAfDWwBuct4l6ZY=", + version = "v0.0.0-20160728113105-d5b7844b561a", +) + +go_repository( + name = "com_github_mattbaird_jsonpatch", + importpath = "github.com/mattbaird/jsonpatch", + sum = "h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8=", + version = "v0.0.0-20171005235357-81af80346b1a", +) + +go_repository( + name = "com_github_modern_go_concurrent", + importpath = "github.com/modern-go/concurrent", + sum = "h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=", + version = "v0.0.0-20180306012644-bacd9c7ef1dd", +) + +go_repository( + name = "com_github_modern_go_reflect2", + importpath = "github.com/modern-go/reflect2", + sum = "h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=", + version = "v1.0.1", +) + +go_repository( + name = "com_github_munnerz_goautoneg", + importpath = "github.com/munnerz/goautoneg", + sum = "h1:7PxY7LVfSZm7PEeBTyK1rj1gABdCO2mbri6GKO1cMDs=", + version = "v0.0.0-20120707110453-a547fc61f48d", +) + +go_repository( + name = "com_github_mxk_go_flowrate", + importpath = "github.com/mxk/go-flowrate", + sum = "h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=", + version = "v0.0.0-20140419014527-cca7078d478f", +) + +go_repository( + name = "com_github_nytimes_gziphandler", + importpath = "github.com/NYTimes/gziphandler", + sum = "h1:lsxEuwrXEAokXB9qhlbKWPpo3KMLZQ5WB5WLQRW1uq0=", + version = "v0.0.0-20170623195520-56545f4a5d46", +) + +go_repository( + name = "com_github_onsi_ginkgo", + importpath = "github.com/onsi/ginkgo", + sum = "h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w=", + version = "v1.8.0", +) + +go_repository( + name = "com_github_onsi_gomega", + importpath = 
"github.com/onsi/gomega", + sum = "h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo=", + version = "v1.5.0", +) + +go_repository( + name = "com_github_peterbourgon_diskv", + importpath = "github.com/peterbourgon/diskv", + sum = "h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=", + version = "v2.0.1+incompatible", +) + +go_repository( + name = "com_github_puerkitobio_purell", + importpath = "github.com/PuerkitoBio/purell", + sum = "h1:0GoNN3taZV6QI81IXgCbxMyEaJDXMSIjArYBCYzVVvs=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_puerkitobio_urlesc", + importpath = "github.com/PuerkitoBio/urlesc", + sum = "h1:JCHLVE3B+kJde7bIEo5N4J+ZbLhp0J1Fs+ulyRws4gE=", + version = "v0.0.0-20160726150825-5bd2802263f2", +) + +go_repository( + name = "com_github_spf13_afero", + importpath = "github.com/spf13/afero", + sum = "h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc=", + version = "v1.2.2", +) + +go_repository( + name = "in_gopkg_fsnotify_v1", + importpath = "gopkg.in/fsnotify.v1", + sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=", + version = "v1.4.7", +) + +go_repository( + name = "in_gopkg_inf_v0", + importpath = "gopkg.in/inf.v0", + sum = "h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=", + version = "v0.9.1", +) + +go_repository( + name = "in_gopkg_tomb_v1", + importpath = "gopkg.in/tomb.v1", + sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=", + version = "v1.0.0-20141024135613-dd632973f1e7", +) + +go_repository( + name = "io_k8s_api", + build_file_proto_mode = "disable_global", + importpath = "k8s.io/api", + sum = "h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs=", + version = "v0.16.13", +) + +go_repository( + name = "io_k8s_apimachinery", + build_file_proto_mode = "disable_global", + importpath = "k8s.io/apimachinery", + sum = "h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc=", + version = "v0.16.14-rc.0", +) + +go_repository( + name = "io_k8s_client_go", + importpath = "k8s.io/client-go", + sum = "h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag=", + version = "v0.16.13", +) + +go_repository( + name = "io_k8s_gengo", + importpath = "k8s.io/gengo", + sum = "h1:4s3/R4+OYYYUKptXPhZKjQ04WJ6EhQQVFdjOFvCazDk=", + version = "v0.0.0-20190128074634-0689ccc1d7d6", +) + +go_repository( + name = "io_k8s_klog", + importpath = "k8s.io/klog", + sum = "h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8=", + version = "v1.0.0", +) + +go_repository( + name = "io_k8s_kube_openapi", + importpath = "k8s.io/kube-openapi", + sum = "h1:PsbYeEz2x7ll6JYUzBEG+DT78910DDTlvn5Ma10F5/E=", + version = "v0.0.0-20200410163147-594e756bea31", +) + +go_repository( + name = "io_k8s_sigs_structured_merge_diff", + importpath = "sigs.k8s.io/structured-merge-diff", + sum = "h1:4Z09Hglb792X0kfOBBJUPFEyvVfQWrYT/l8h5EKA6JQ=", + version = "v0.0.0-20190525122527-15d366b2352e", +) + +go_repository( + name = "io_k8s_sigs_yaml", + importpath = "sigs.k8s.io/yaml", + sum = "h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs=", + version = "v1.1.0", +) + +go_repository( + name = "io_k8s_utils", + importpath = "k8s.io/utils", + sum = "h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=", + version = "v0.0.0-20190801114015-581e00157fb1", +) diff --git a/g3doc/user_guide/containerd/quick_start.md b/g3doc/user_guide/containerd/quick_start.md index b6a3186d8..a98fe5c4a 100644 --- a/g3doc/user_guide/containerd/quick_start.md +++ b/g3doc/user_guide/containerd/quick_start.md @@ -1,7 +1,7 @@ # Containerd Quick Start -This document describes how to install and configure `containerd-shim-runsc-v1` -using the containerd runtime handler support on 
`containerd` 1.2 or later. +This document describes how to use `containerd-shim-runsc-v1` with the +containerd runtime handler support on `containerd` 1.2 or later. > ⚠️ NOTE: If you are using Kubernetes and set up your cluster using kubeadm you > may run into issues. See the [FAQ](../FAQ.md#runtime-handler) for details. diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index abb9e8582..c3ced9d61 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -13,15 +13,19 @@ To download and install the latest release manually follow these steps: ( set -e URL=https://storage.googleapis.com/gvisor/releases/release/latest - wget ${URL}/runsc ${URL}/runsc.sha512 - sha512sum -c runsc.sha512 - rm -f runsc.sha512 - sudo mv runsc /usr/local/bin - sudo chmod a+rx /usr/local/bin/runsc + wget ${URL}/runsc ${URL}/runsc.sha512 \ + ${URL}/gvisor-containerd-shim ${URL}/gvisor-containerd-shim.sha512 \ + ${URL}/containerd-shim-runsc-v1 ${URL}/containerd-shim-runsc-v1.sha512 + sha512sum -c runsc.sha512 \ + -c gvisor-containerd-shim.sha512 \ + -c containerd-shim-runsc-v1.sha512 + rm -f *.sha512 + chmod a+rx runsc gvisor-containerd-shim containerd-shim-runsc-v1 + sudo mv runsc gvisor-containerd-shim containerd-shim-runsc-v1 /usr/local/bin ) ``` -To install gVisor with Docker, run the following commands: +To install gVisor as a Docker runtime, run the following commands: ```bash /usr/local/bin/runsc install @@ -165,5 +169,6 @@ You can use this link with the steps described in Note that `apt` installation of a specific point release is not supported. After installation, try out `runsc` by following the -[Docker Quick Start](./quick_start/docker.md) or +[Docker Quick Start](./quick_start/docker.md), +[Containerd QuickStart](./containerd/quick_start.md), or [OCI Quick Start](./quick_start/oci.md). 
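The install.md and containerd quick-start changes above stop at putting `runsc` and `containerd-shim-runsc-v1` on the `PATH`; containerd still has to be told about the new runtime handler. A minimal sketch of that wiring, assuming containerd's CRI plugin and the 1.2-era config layout (the exact plugin table path differs on newer containerd releases, so treat this as illustrative rather than the documented steps):

```bash
# Illustrative only: register runsc as a containerd runtime handler.
# containerd 1.2 uses the table below; containerd 1.3+ expects
# [plugins."io.containerd.grpc.v1.cri".containerd.runtimes.runsc] instead.
cat <<'EOF' | sudo tee -a /etc/containerd/config.toml
[plugins.cri.containerd.runtimes.runsc]
  runtime_type = "io.containerd.runsc.v1"
EOF
sudo systemctl restart containerd
```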
@@ -29,11 +29,12 @@ require ( github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect - github.com/google/go-cmp v0.5.1 // indirect + github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd // indirect github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect github.com/hashicorp/go-multierror v1.0.0 // indirect github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect + github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/opencontainers/runc v0.1.1 // indirect @@ -43,12 +44,13 @@ require ( github.com/urfave/cli v1.22.2 // indirect github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect - go.uber.org/atomic v1.7.0 // indirect - go.uber.org/multierr v1.2.0 // indirect + go.uber.org/multierr v1.6.0 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect - golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d // indirect + golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 // indirect google.golang.org/grpc v1.29.0 // indirect - google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 // indirect - gopkg.in/yaml.v2 v2.2.8 // indirect + google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect gotest.tools v2.2.0+incompatible // indirect + k8s.io/api v0.16.13 + k8s.io/apimachinery v0.16.14-rc.0 + k8s.io/client-go v0.16.13 ) @@ -13,6 +13,13 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= @@ -27,6 +34,9 @@ github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg3 github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= github.com/Microsoft/hcsshim v0.8.10 h1:k5wTrpnVU2/xv8ZuzGkbXVd3js5zJ8RnumPo5RxiIxU= github.com/Microsoft/hcsshim v0.8.10/go.mod 
h1:g5uw8EV2mAlzqe94tfNBNdr89fnbD/n3HV0OhsddkmM= +github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= +github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= @@ -68,8 +78,11 @@ github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQa github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w= @@ -80,14 +93,25 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= +github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch 
v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= +github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= +github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg= +github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= +github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= @@ -96,9 +120,11 @@ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1 github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= +github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -106,6 +132,7 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -125,11 +152,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI= +github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= +github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -137,18 +167,28 @@ github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hf github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k= +github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= +github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod 
h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= @@ -164,8 +204,25 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8= +github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= 
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= @@ -181,9 +238,12 @@ github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNia github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= @@ -196,12 +256,18 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= @@ -217,12 +283,15 @@ go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr 
v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= -go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -247,6 +316,7 @@ golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -275,8 +345,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -285,6 +358,7 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod 
h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -295,6 +369,7 @@ golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -304,6 +379,7 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -322,8 +398,8 @@ golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE= -golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk= +golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -349,9 +425,8 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -359,7 +434,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -368,13 +442,18 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w= -google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A= +google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod 
h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -383,4 +462,22 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +k8s.io/api v0.16.13 h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs= +k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs= +k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= +k8s.io/apimachinery v0.16.14-rc.0 h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc= +k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= +k8s.io/client-go v0.16.13 h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag= +k8s.io/client-go v0.16.13/go.mod h1:UKvVT4cajC2iN7DCjLgT0KVY/cbY6DGdUCyRiIfws5M= +k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0= +k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= +k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= +k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8= +k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= +k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E= +k8s.io/utils v0.0.0-20190801114015-581e00157fb1 h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE= +k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI= +sigs.k8s.io/yaml v1.1.0 h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= diff --git a/images/README.md b/images/README.md index 9880946a6..297c7c3f3 100644 --- a/images/README.md +++ b/images/README.md @@ -41,9 +41,9 @@ All images will be tagged and memoized using a hash of the directory contents. As a result, every image should be made completely reproducible if possible. This means using fixed tags and fixed versions whenever feasible. -Notes that images should also be made architecture-independent if possible. The -build scripts will handling loading the appropriate architecture onto the -machine and tagging it with the single canonical tag. +Note that images should also be made architecture-independent if possible. The +build scripts will handle loading the appropriate architecture onto the machine +and tagging it with the single canonical tag. Add a `load-<image>` dependency in the Makefile if the image is required for a particular set of tests. 
This target will pull the tag from the image repository diff --git a/images/defs.bzl b/images/defs.bzl index 61d7bbf73..c1f96e312 100644 --- a/images/defs.bzl +++ b/images/defs.bzl @@ -2,30 +2,33 @@ def _docker_image_impl(ctx): importer = ctx.actions.declare_file(ctx.label.name) + importer_content = [ "#!/bin/bash", "set -euo pipefail", + "source_file='%s'" % ctx.file.data.path, + "if [[ ! -f \"$source_file\" ]]; then", + " source_file='%s'" % ctx.file.data.short_path, + "fi", "exec docker import " + " ".join([ "-c '%s'" % attr for attr in ctx.attr.statements - ]) + " " + " ".join([ - "'%s'" % f.path - for f in ctx.files.data - ]) + " $1", + ]) + " \"$source_file\" $1", "", ] + ctx.actions.write(importer, "\n".join(importer_content), is_executable = True) return [DefaultInfo( - runfiles = ctx.runfiles(ctx.files.data), + runfiles = ctx.runfiles([ctx.file.data]), executable = importer, )] docker_image = rule( implementation = _docker_image_impl, - doc = "Tool to load a Docker image; takes a single parameter (image name).", + doc = "Tool to import a Docker image; takes a single parameter (image name).", attrs = { "statements": attr.string_list(doc = "Extra Dockerfile directives."), - "data": attr.label_list(doc = "All image data."), + "data": attr.label(doc = "Image filesystem tarball", allow_single_file = [".tgz", ".tar.gz"]), }, executable = True, ) diff --git a/nogo.yaml b/nogo.yaml new file mode 100644 index 000000000..5c1737f59 --- /dev/null +++ b/nogo.yaml @@ -0,0 +1,253 @@ +groups: + # We define three basic groups: generated (all generated files), + # external (all files outside the repository), and internal (all + # files within the local repository). We can't enforce many style + # checks on generated and external code, so enable those cases + # selectively for analyzers below. + - name: generated + regex: "^(bazel-genfiles|bazel-out|bazel-bin)/" + default: true + - name: external + regex: "^external/" + default: false + - name: internal + regex: ".*" + default: true +global: + generated: + suppress: + # Suppress the basic style checks for + # generated code, but keep the analysis + # that are required for quality & security. + - "should not use ALL_CAPS in Go names" + - "should not use underscores" + - "comment on exported" + - "methods on the same type should have the same receiver name" + - "at least one file in a package" + - "package comment should be of the form" + # Generated code may have dead code paths. + - "identical build constraints" + - "no value of type" + - "is never used" + # go_embed_data rules generate unicode literals. + - "string literal contains the Unicode format character" + - "string literal contains the Unicode control character" + - "string literal contains Unicode control characters" + - "string literal contains Unicode format and control characters" + # Some external code will generate protov1 + # implementations. These should be ignored. + - "proto.* is deprecated" + - "xxx_messageInfo_.*" + - "receiver name should be a reflection of its identity" + # Generated gRPC code is not compliant either. + - "error strings should not be capitalized" + - "grpc.Errorf is deprecated" + # Generated proto code does not always follow capitalization conventions. + - "(field|method|struct|type) .* should be .*" + # Generated proto code sometimes duplicates imports with aliases. + - "duplicate import" + internal: + suppress: + # We use ALL_CAPS for system definitions, + # which are common enough in the code base + # that we shouldn't annotate exceptions. 
+ # + # Same story for underscores. + - "should not use ALL_CAPS in Go names" + - "should not use underscores in Go names" + exclude: + # A variety of staticcheck and stylecheck + # rules apply here. These should be fixed + # and removed from here, and the global + # rules should be used sparingly. + - pkg/abi/linux/fuse.go:22 + - pkg/abi/linux/fuse.go:25 + - pkg/abi/linux/socket.go:113 + - pkg/abi/linux/tty.go:73 + - pkg/cpuid/cpuid_x86.go:675 + - pkg/gohacks/gohacks_unsafe.go:33 + - pkg/log/json.go:30 + - pkg/log/log.go:359 + - pkg/metric/metric_test.go:20 + - pkg/p9/p9test/client_test.go:687 + - pkg/p9/transport_test.go:196 + - pkg/pool/pool.go:15 + - pkg/refs/refcounter.go:510 + - pkg/refs/refcounter_test.go:169 + - pkg/refs_vfs2/refs.go:16 + - pkg/safemem/block_unsafe.go:89 + - pkg/seccomp/seccomp.go:82 + - pkg/segment/test/set_functions.go:15 + - pkg/sentry/arch/signal.go:166 + - pkg/sentry/arch/signal.go:171 + - pkg/sentry/control/pprof.go:196 + - pkg/sentry/devices/memdev/full.go:58 + - pkg/sentry/devices/memdev/null.go:59 + - pkg/sentry/devices/memdev/random.go:68 + - pkg/sentry/devices/memdev/zero.go:86 + - pkg/sentry/fdimport/fdimport.go:15 + - pkg/sentry/fs/attr.go:257 + - pkg/sentry/fsbridge/fs.go:116 + - pkg/sentry/fsbridge/vfs.go:124 + - pkg/sentry/fsbridge/vfs.go:70 + - pkg/sentry/fs/copy_up.go:365 + - pkg/sentry/fs/copy_up_test.go:65 + - pkg/sentry/fs/dev/net_tun.go:161 + - pkg/sentry/fs/dev/net_tun.go:63 + - pkg/sentry/fs/dev/null.go:97 + - pkg/sentry/fs/dirent_cache.go:64 + - pkg/sentry/fs/fdpipe/pipe_opener_test.go:366 + - pkg/sentry/fs/file_overlay.go:327 + - pkg/sentry/fs/file_overlay.go:524 + - pkg/sentry/fs/filetest/filetest.go:55 + - pkg/sentry/fs/filetest/filetest.go:60 + - pkg/sentry/fs/fs.go:77 + - pkg/sentry/fs/fsutil/file.go:290 + - pkg/sentry/fs/fsutil/file.go:346 + - pkg/sentry/fs/fsutil/host_file_mapper.go:105 + - pkg/sentry/fs/fsutil/inode_cached.go:676 + - pkg/sentry/fs/fsutil/inode_cached.go:772 + - pkg/sentry/fs/gofer/attr.go:120 + - pkg/sentry/fs/gofer/fifo.go:33 + - pkg/sentry/fs/gofer/inode.go:410 + - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97 + - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92 + - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44 + - pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91 + - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93 + - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66 + - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53 + - pkg/sentry/fsimpl/fuse/request_response.go:71 + - pkg/sentry/fsimpl/signalfd/signalfd.go:15 + - pkg/sentry/memmap/memmap.go:103 + - pkg/sentry/memmap/memmap.go:163 + - pkg/sentry/mm/aio_context.go:208 + - pkg/sentry/mm/pma.go:683 + - pkg/sentry/usage/cpu.go:42 + - pkg/shim/runsc/runsc.go:16 + - pkg/shim/runsc/utils.go:16 + - pkg/shim/v1/proc/deleted_state.go:16 + - pkg/shim/v1/proc/exec.go:16 + - pkg/shim/v1/proc/exec_state.go:16 + - pkg/shim/v1/proc/init.go:16 + - pkg/shim/v1/proc/init_state.go:16 + - pkg/shim/v1/proc/io.go:16 + - pkg/shim/v1/proc/process.go:16 + - pkg/shim/v1/proc/types.go:16 + - pkg/shim/v1/proc/utils.go:16 + - pkg/shim/v1/shim/api.go:16 + - pkg/shim/v1/shim/platform.go:16 + - pkg/shim/v1/shim/service.go:16 + - pkg/shim/v1/utils/annotations.go:15 + - pkg/shim/v1/utils/utils.go:15 + - pkg/shim/v1/utils/volumes.go:15 + - pkg/shim/v2/api.go:16 + - pkg/shim/v2/epoll.go:18 + - pkg/shim/v2/options/options.go:15 + - pkg/shim/v2/options/options.go:24 + - pkg/shim/v2/options/options.go:26 + - pkg/shim/v2/runtimeoptions/runtimeoptions.go:16 + - 
pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go # Generated: exempt all. + - pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22 + - pkg/shim/v2/service.go:15 + - pkg/shim/v2/service_linux.go:18 + - pkg/state/tests/integer_test.go:23 + - pkg/state/tests/integer_test.go:28 + - pkg/sync/rwmutex_test.go:105 + - pkg/syserr/host_linux.go:35 + - pkg/usermem/addr.go:34 + - pkg/usermem/usermem.go:171 + - pkg/usermem/usermem.go:170 + - runsc/boot/compat.go:56 + - test/cmd/test_app/fds.go:171 + - test/iptables/filter_output.go:251 + - test/packetimpact/testbench/connections.go:77 + - tools/bigquery/bigquery.go:106 + - tools/checkescape/test1/test1.go:108 + - tools/checkescape/test1/test1.go:122 + - tools/checkescape/test1/test1.go:137 + - tools/checkescape/test1/test1.go:151 + - tools/checkescape/test1/test1.go:170 + - tools/checkescape/test1/test1.go:39 + - tools/checkescape/test1/test1.go:45 + - tools/checkescape/test1/test1.go:50 + - tools/checkescape/test1/test1.go:64 + - tools/checkescape/test1/test1.go:80 + - tools/checkescape/test1/test1.go:94 +analyzers: + asmdecl: + external: # Enabled. + assign: + external: + exclude: + - gazelle/walk/walk.go + atomic: + external: # Enabled. + bools: + external: # Enabled. + buildtag: + external: # Enabled. + cgocall: + external: # Enabled. + shadow: # Disable for now. + generated: + exclude: [".*"] + internal: + exclude: [".*"] + composites: # Disable for now. + generated: + exclude: [".*"] + internal: + exclude: [".*"] + errorsas: + external: # Enabled. + httpresponse: + external: # Enabled. + loopclosure: + external: # Enabled. + nilfunc: + external: # Enabled. + nilness: + internal: + exclude: + - pkg/sentry/platform/kvm/kvm_test.go # Intentional. + - tools/bigquery/bigquery.go # False positive. + printf: + external: # Enabled. + shift: + external: # Enabled. + stringintconv: + external: + exclude: + - ".*protobuf/.*.go" # Bad conversions. + - ".*flate/huffman_bit_writer.go" # Bad conversion. + # Runtime internal violations. + - ".*reflect/value.go" + - ".*encoding/xml/xml.go" + - ".*runtime/pprof/internal/profile/proto.go" + - ".*fmt/scan.go" + - ".*go/types/conversions.go" + - ".*golang.org/x/net/dns/dnsmessage/message.go" + tests: + external: # Enabled. + unmarshal: + external: # Enabled. + unreachable: + external: # Enabled. + unsafeptr: + internal: + exclude: + - ".*_test.go" # Exclude tests. + - "pkg/flipcall/.*_unsafe.go" # Special case. + - pkg/gohacks/gohacks_unsafe.go # Special case. + - pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go # Special case. + - pkg/sentry/platform/kvm/bluepill_unsafe.go # Special case. + - pkg/sentry/platform/kvm/machine_unsafe.go # Special case. + - pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go # Special case. + - pkg/sentry/platform/safecopy/safecopy_unsafe.go # Special case. + - pkg/sentry/vfs/mount_unsafe.go # Special case. + - pkg/state/decode_unsafe.go # Special case. + unusedresult: + external: # Enabled. + checkescape: + external: # Enabled. diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 4a26e28de..a0654df2f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -55,6 +55,8 @@ go_library( "sched.go", "seccomp.go", "sem.go", + "sem_amd64.go", + "sem_arm64.go", "shm.go", "signal.go", "signalfd.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 7df02dd6d..006b5a525 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,6 +121,9 @@ const ( // Constants from uapi/linux/fsverity.h. 
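Editor's note on the new nogo.yaml above: analyzer findings are scoped by path groups (generated, external, internal) and filtered by message patterns, with per-analyzer excludes layered on top. The standalone sketch below only illustrates the kind of matching this configuration implies; the types, first-match ordering, and function names are illustrative assumptions rather than nogo's actual implementation, and the regexes are copied from the config above.

```go
// Illustrative sketch of the path-group and suppression matching implied by
// the nogo.yaml above. Not nogo's real API; regexes are taken from the config.
package main

import (
	"fmt"
	"regexp"
)

type group struct {
	name    string
	regex   *regexp.Regexp
	enabled bool // the "default" field in nogo.yaml
}

// Order matters here: the catch-all "internal" group (".*") comes last,
// mirroring the config above.
var groups = []group{
	{"generated", regexp.MustCompile(`^(bazel-genfiles|bazel-out|bazel-bin)/`), true},
	{"external", regexp.MustCompile(`^external/`), false},
	{"internal", regexp.MustCompile(`.*`), true},
}

// Two of the "suppress" patterns listed for generated code above.
var generatedSuppress = []*regexp.Regexp{
	regexp.MustCompile(`should not use ALL_CAPS in Go names`),
	regexp.MustCompile(`grpc\.Errorf is deprecated`),
}

// report decides whether a finding at path with message msg would surface.
func report(path, msg string) bool {
	for _, g := range groups {
		if !g.regex.MatchString(path) {
			continue
		}
		if !g.enabled {
			return false // e.g. external code: checks disabled by default
		}
		if g.name == "generated" {
			for _, s := range generatedSuppress {
				if s.MatchString(msg) {
					return false // style check suppressed for generated code
				}
			}
		}
		return true
	}
	return false
}

func main() {
	fmt.Println(report("bazel-out/foo.pb.go", "grpc.Errorf is deprecated"))     // false: suppressed
	fmt.Println(report("pkg/sentry/kernel/task.go", "printf: wrong arg count")) // true: internal, not suppressed
}
```

The per-file entries under the internal "exclude" list work at a different granularity: they key on path:line rather than on the diagnostic text.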
const ( + FS_VERITY_HASH_ALG_SHA256 = 1 + FS_VERITY_HASH_ALG_SHA512 = 2 + FS_IOC_ENABLE_VERITY = 1082156677 FS_IOC_MEASURE_VERITY = 3221513862 ) diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 487a626cc..1b2f76c0b 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -34,18 +34,6 @@ const ( const SEM_UNDO = 0x1000 -// SemidDS is equivalent to struct semid64_ds. -// -// +marshal -type SemidDS struct { - SemPerm IPCPerm - SemOTime TimeT - SemCTime TimeT - SemNSems uint64 - unused3 uint64 - unused4 uint64 -} - // Sembuf is equivalent to struct sembuf. // // +marshal slice:SembufSlice diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go new file mode 100644 index 000000000..ab980cb4f --- /dev/null +++ b/pkg/abi/linux/sem_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: arch/x86/include/uapi/asm/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + unused1 uint64 + SemCTime TimeT + unused2 uint64 + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go new file mode 100644 index 000000000..521468fb1 --- /dev/null +++ b/pkg/abi/linux/sem_arm64.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: include/uapi/asm-generic/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + SemCTime TimeT + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index 069d0395d..6d1e65cb1 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error { case B: w.WriteString("1") default: - return fmt.Errorf("Invalid BPF LD size: %v", inst) + return fmt.Errorf("invalid BPF LD size: %v", inst) } return nil } diff --git a/pkg/context/context.go b/pkg/context/context.go index 2613bc752..f3031fc60 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()} func Background() Context { return bgContext } + +// WithValue returns a copy of parent in which the value associated with key is +// val. 
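The two FS_IOC_* request numbers added to pkg/abi/linux/ioctl.go above are the decimal forms of the standard Linux ioctl encoding; per include/uapi/linux/fsverity.h they correspond to _IOW('f', 133, struct fsverity_enable_arg) and _IOWR('f', 134, struct fsverity_digest). A small sketch that decodes the bit fields, assuming the usual amd64/arm64 ioctl layout (nr in bits 0-7, type in 8-15, size in 16-29, direction in 30-31):

```go
// Decode the fsverity ioctl request numbers added above into the standard
// Linux ioctl fields (this layout holds on amd64 and arm64).
package main

import "fmt"

func decode(req uint32) (nr, typ, size, dir uint32) {
	return req & 0xff, (req >> 8) & 0xff, (req >> 16) & 0x3fff, req >> 30
}

func main() {
	for _, req := range []uint32{1082156677, 3221513862} {
		nr, typ, size, dir := decode(req)
		fmt.Printf("%#x: nr=%d type=%q size=%d dir=%d\n", req, nr, typ, size, dir)
	}
	// Prints:
	//   0x40806685: nr=133 type='f' size=128 dir=1  -> _IOW('f', 133, 128-byte fsverity_enable_arg)
	//   0xc0046686: nr=134 type='f' size=4 dir=3    -> _IOWR('f', 134, 4-byte fsverity_digest header)
}
```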
+func WithValue(parent Context, key, val interface{}) Context { + return &withValue{ + Context: parent, + key: key, + val: val, + } +} + +type withValue struct { + Context + key interface{} + val interface{} +} + +// Value implements Context.Value. +func (ctx *withValue) Value(key interface{}) interface{} { + if key == ctx.key { + return ctx.val + } + return ctx.Context.Value(key) +} diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index a8fcb2e19..501a9ef21 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -6,12 +6,18 @@ go_library( name = "merkletree", srcs = ["merkletree.go"], visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) go_test( name = "merkletree_test", srcs = ["merkletree_test.go"], library = ":merkletree", - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d8227b8bd..e0a9e56c5 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -18,21 +18,32 @@ package merkletree import ( "bytes" "crypto/sha256" + "crypto/sha512" "fmt" "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) const ( // sha256DigestSize specifies the digest size of a SHA256 hash. sha256DigestSize = 32 + // sha512DigestSize specifies the digest size of a SHA512 hash. + sha512DigestSize = 64 ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). -func DigestSize() int { - return sha256DigestSize +// TODO(b/156980949): Allow config SHA384. +func DigestSize(hashAlgorithm int) int { + switch hashAlgorithm { + case linux.FS_VERITY_HASH_ALG_SHA256: + return sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + return sha512DigestSize + default: + return -1 + } } // Layout defines the scale of a Merkle tree. @@ -51,11 +62,19 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { +func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ blockSize: usermem.PageSize, - // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, + } + + // TODO(b/156980949): Allow config SHA384. + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + layout.digestSize = sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + layout.digestSize = sha512DigestSize + default: + return Layout{}, fmt.Errorf("unexpected hash algorithms") } // treeStart is the offset (in bytes) of the first level of the tree in @@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { } layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) - return layout + return layout, nil } // hashesPerBlock() returns the number of digests in each block. For example, @@ -128,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 { // meatadata. 
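The WithValue helper added to pkg/context above mirrors the standard library's context.WithValue for the sentry's own Context type. A minimal usage sketch, assuming the patched gvisor.dev/gvisor/pkg/context package; the ctxKey type is illustrative:

```go
// Minimal usage sketch for the WithValue helper added above, assuming the
// patched gvisor.dev/gvisor/pkg/context package. As with the standard
// library, an unexported key type avoids collisions between packages.
package main

import (
	"fmt"

	"gvisor.dev/gvisor/pkg/context"
)

// ctxKey is an illustrative key type; any comparable value works.
type ctxKey struct{}

func main() {
	ctx := context.WithValue(context.Background(), ctxKey{}, "payload")

	fmt.Println(ctx.Value(ctxKey{})) // "payload": the key matches this wrapper
	fmt.Println(ctx.Value("other"))  // unknown keys are delegated to the wrapped context
}
```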
type VerityDescriptor struct { Name string + FileSize int64 Mode uint32 UID uint32 GID uint32 @@ -135,16 +155,37 @@ type VerityDescriptor struct { } func (d *VerityDescriptor) String() string { - return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash) + return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash) } // verify generates a hash from d, and compares it with expected. -func (d *VerityDescriptor) verify(expected []byte) error { - h := sha256.Sum256([]byte(d.String())) +func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { + h, err := hashData([]byte(d.String()), hashAlgorithms) + if err != nil { + return err + } if !bytes.Equal(h[:], expected) { return fmt.Errorf("unexpected root hash") } return nil + +} + +// hashData hashes data and returns the result hash based on the hash +// algorithms. +func hashData(data []byte, hashAlgorithms int) ([]byte, error) { + var digest []byte + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + digestArray := sha256.Sum256(data) + digest = digestArray[:] + case linux.FS_VERITY_HASH_ALG_SHA512: + digestArray := sha512.Sum512(data) + digest = digestArray[:] + default: + return nil, fmt.Errorf("unexpected hash algorithms") + } + return digest, nil } // GenerateParams contains the parameters used to generate a Merkle tree. @@ -161,6 +202,8 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // TreeReader is a reader for the Merkle tree. TreeReader io.ReaderAt // TreeWriter is a writer for the Merkle tree. @@ -176,7 +219,10 @@ type GenerateParams struct { // Generate returns a hash of a VerityDescriptor, which contains the file // metadata and the hash from file content. func Generate(params *GenerateParams) ([]byte, error) { - layout := InitLayout(params.Size, params.DataAndTreeInSameFile) + layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return nil, err + } numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize @@ -218,10 +264,13 @@ func Generate(params *GenerateParams) ([]byte, error) { return nil, err } // Hash the bytes in buf. - digest := sha256.Sum256(buf) + digest, err := hashData(buf, params.HashAlgorithms) + if err != nil { + return nil, err + } if level == layout.rootLevel() { - root = digest[:] + root = digest } // Write the generated hash to the end of the tree file. @@ -241,13 +290,13 @@ func Generate(params *GenerateParams) ([]byte, error) { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, RootHash: root, } - ret := sha256.Sum256([]byte(descriptor.String())) - return ret[:], nil + return hashData([]byte(descriptor.String()), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -269,6 +318,8 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // ReadOffset is the offset of the data range to be verified. ReadOffset int64 // ReadSize is the size of the data range to be verified. 
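Worth spelling out from the hunks above: the value Generate returns is not the raw Merkle root but the hash of the VerityDescriptor string, which now also carries FileSize, so a tampered size changes the expected hash (exactly what the new modifySize test case below relies on). A standard-library-only sketch mirroring that descriptor hashing; the format string is copied from VerityDescriptor.String above, and the concrete values are made up:

```go
// Standalone illustration of how the final hash returned by Generate binds
// file metadata. The format string is copied from VerityDescriptor.String
// above; the values and the placeholder root hash are invented for the example.
package main

import (
	"crypto/sha256"
	"fmt"
)

func descriptorHash(name string, size int64, mode, uid, gid uint32, root []byte) [32]byte {
	s := fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v",
		name, size, mode, uid, gid, root)
	return sha256.Sum256([]byte(s))
}

func main() {
	root := []byte{0xaa, 0xbb} // placeholder Merkle root
	good := descriptorHash("data.img", 4096, 0644, 0, 0, root)
	bad := descriptorHash("data.img", 4095, 0644, 0, 0, root) // size off by one

	// Because FileSize is part of the descriptor, the two digests differ, so
	// verification against the expected hash fails when the size is tampered with.
	fmt.Printf("equal: %v\n", good == bad) // equal: false
}
```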
@@ -293,12 +344,13 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, RootHash: root, } - return descriptor.verify(params.Expected) + return descriptor.verify(params.Expected, params.HashAlgorithms) } // Verify verifies the content read from data with offset. The content is @@ -313,7 +365,10 @@ func Verify(params *VerifyParams) (int64, error) { if params.ReadSize < 0 { return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) } - layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return 0, err + } if params.ReadSize == 0 { return 0, verifyMetadata(params, &layout) } @@ -349,12 +404,13 @@ func Verify(params *VerifyParams) (int64, error) { } } descriptor := VerityDescriptor{ - Name: params.Name, - Mode: params.Mode, - UID: params.UID, - GID: params.GID, + Name: params.Name, + FileSize: params.Size, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil { + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err } @@ -395,7 +451,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -406,8 +462,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, for level := 0; level < layout.numLevels(); level++ { // Calculate hash. if level == 0 { - digestArray := sha256.Sum256(dataBlock) - digest = digestArray[:] + h, err := hashData(dataBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } else { // Read a block in previous level that contains the // hash we just generated, and generate a next level @@ -415,8 +474,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil { return err } - digestArray := sha256.Sum256(treeBlock) - digest = digestArray[:] + h, err := hashData(treeBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } // Read the digest for the current block and store in @@ -434,5 +496,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. 
descriptor.RootHash = digest - return descriptor.verify(expected) + return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index e1350ebda..405204d94 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -22,54 +22,114 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) func TestLayout(t *testing.T) { testCases := []struct { dataSize int64 + hashAlgorithms int dataAndTreeInSameFile bool + expectedDigestSize int64 expectedLevelOffset []int64 }{ { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0}, }, { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedDigestSize: 32, + expectedLevelOffset: []int64{usermem.PageSize}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, + expectedDigestSize: 64, expectedLevelOffset: []int64{usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + }, + { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + }, + { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) + if err != nil { + t.Fatalf("Failed to InitLayout: %v", err) + } if l.blockSize != int64(usermem.PageSize) { 
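A quick sanity check of the expectedLevelOffset values in the TestLayout table above: doubling the digest size to 64 bytes (SHA-512) halves the number of digests that fit in one tree block, which is why the per-level offsets shift. A minimal sketch, assuming 4096-byte pages (as the table's constants do) and the layout visible in InitLayout (digests packed into page-sized tree blocks, tree appended after the data when both share a file), that reproduces the table's numbers:

```go
// Reproduce the expectedLevelOffset values from TestLayout above under the
// stated layout assumptions.
package main

import "fmt"

const blockSize = 4096 // usermem.PageSize on the platforms the table assumes

func levelOffsets(dataSize, digestSize int64, sameFile bool) []int64 {
	numBlocks := (dataSize + blockSize - 1) / blockSize
	hashesPerBlock := blockSize / digestSize

	var treeStart int64
	if sameFile {
		treeStart = numBlocks * blockSize // tree follows the data
	}

	offsets := []int64{treeStart}
	// Each level hashes the blocks of the level below until a single digest
	// (the root) remains.
	for digests := numBlocks; digests > 1; {
		blocks := (digests + hashesPerBlock - 1) / hashesPerBlock
		offsets = append(offsets, offsets[len(offsets)-1]+blocks*blockSize)
		digests = blocks
	}
	return offsets
}

func main() {
	// SHA-512 (64-byte digests), 1,000,000 bytes, data and tree in one file:
	// 245 data blocks -> 4 level-0 tree blocks -> 1 root block.
	fmt.Println(levelOffsets(1000000, 64, true)) // [1003520 1019904 1024000] == {245, 249, 250} * 4096

	// SHA-256 (32-byte digests), same input, matches {245, 247, 248} * 4096.
	fmt.Println(levelOffsets(1000000, 32, true)) // [1003520 1011712 1015808]
}
```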
t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if l.digestSize != sha256DigestSize { + if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { @@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - expectedHash []byte + data []byte + hashAlgorithms int + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147}, + }, + { + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171}, }, { - data: []byte{'a'}, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + 
hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246}, + }, + { + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84}, }, } @@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) { Mode: defaultMode, UID: defaultUID, GID: defaultGID, + HashAlgorithms: tc.hashAlgorithms, TreeReader: &tree, TreeWriter: &tree, DataAndTreeInSameFile: dataAndTreeInSameFile, @@ -189,6 +275,7 @@ func TestVerify(t *testing.T) { // fail, otherwise Verify should still succeed. modifyByte int64 modifyName bool + modifySize bool modifyMode bool modifyUID bool modifyGID bool @@ -237,6 +324,15 @@ func TestVerify(t *testing.T) { modifyName: true, shouldSucceed: false, }, + // Modified size should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifySize: true, + shouldSucceed: false, + }, // Modified mode should fail verification. { dataSize: usermem.PageSize, @@ -348,77 +444,84 @@ func TestVerify(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Flip a bit in data and checks Verify results. 
- var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: tc.dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: tc.verifyStart, + ReadSize: tc.verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + if tc.modifyName { + verifyParams.Name = defaultName + "abc" } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + if tc.modifySize { + verifyParams.Size-- } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") + if tc.modifyMode { + verifyParams.Mode = defaultMode + 1 } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") + if tc.modifyUID { + verifyParams.UID = defaultUID + 1 + } + if tc.modifyGID { + verifyParams.GID = defaultGID + 1 + } + if tc.shouldSucceed { + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&verifyParams); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } } } } @@ -435,87 +538,91 @@ func TestVerifyRandom(t *testing.T) { // Generate random bytes in data. 
rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + var buf bytes.Buffer + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: start, + ReadSize: size, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - // Verify that modified metadata should fail verification. 
- buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verify succeeded for modified metadata, expect failure") - } + // Verify that modified metadata should fail verification. + buf.Reset() + verifyParams.Name = defaultName + "abc" + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verify succeeded for modified metadata, expect failure") + } - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + verifyParams.File = bytes.NewReader(data) + verifyParams.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verification succeeded for modified data, expect failure") + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verification succeeded for modified data, expect failure") + } } } } diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 699ea8ac3..6992e1de8 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -319,7 +319,8 @@ func makeStackKey(pcs []uintptr) stackKey { return key } -func recordStack() []uintptr { +// RecordStack constructs and returns the PCs on the current stack. +func RecordStack() []uintptr { pcs := make([]uintptr, maxStackFrames) n := runtime.Callers(1, pcs) if n == 0 { @@ -342,7 +343,8 @@ func recordStack() []uintptr { return v } -func formatStack(pcs []uintptr) string { +// FormatStack converts the given stack into a readable format. +func FormatStack(pcs []uintptr) string { frames := runtime.CallersFrames(pcs) var trace bytes.Buffer for { @@ -367,7 +369,7 @@ func (r *AtomicRefCount) finalize() { if n := r.ReadRefs(); n != 0 { msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n) if len(r.stack) != 0 { - msg += ":\nCaller:\n" + formatStack(r.stack) + msg += ":\nCaller:\n" + FormatStack(r.stack) } else { msg += " (enable trace logging to debug)" } @@ -392,7 +394,7 @@ func (r *AtomicRefCount) EnableLeakCheck(name string) { case NoLeakChecking: return case LeaksLogTraces: - r.stack = recordStack() + r.stack = RecordStack() } r.name = name runtime.SetFinalizer(r, (*AtomicRefCount).finalize) diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD index 577b827a5..bfa1daa10 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -8,6 +8,9 @@ go_template( srcs = [ "refs_template.go", ], + opt_consts = [ + "logTrace", + ], types = [ "T", ], @@ -19,8 +22,16 @@ go_template( ) go_library( - name = "refs_vfs2", - srcs = ["refs.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/context"], + name = "refsvfs2", + srcs = [ + "refs.go", + "refs_map.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/context", + "//pkg/log", + "//pkg/refs", + "//pkg/sync", + ], ) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go index 99a074e96..ef8beb659 100644 --- a/pkg/refs_vfs2/refs.go +++ b/pkg/refsvfs2/refs.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_vfs2 defines an interface for a reference-counted object. -package refs_vfs2 +// Package refsvfs2 defines an interface for a reference-counted object. 
+package refsvfs2 import ( "gvisor.dev/gvisor/pkg/context" diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go new file mode 100644 index 000000000..9fbc5466f --- /dev/null +++ b/pkg/refsvfs2/refs_map.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package refsvfs2 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" +) + +var ( + // liveObjects is a global map of reference-counted objects. Objects are + // inserted when leak check is enabled, and they are removed when they are + // destroyed. It is protected by liveObjectsMu. + liveObjects map[CheckedObject]struct{} + liveObjectsMu sync.Mutex +) + +// CheckedObject represents a reference-counted object with an informative +// leak detection message. +type CheckedObject interface { + // RefType is the type of the reference-counted object. + RefType() string + + // LeakMessage supplies a warning to be printed upon leak detection. + LeakMessage() string + + // LogRefs indicates whether reference-related events should be logged. + LogRefs() bool +} + +func init() { + liveObjects = make(map[CheckedObject]struct{}) +} + +// leakCheckEnabled returns whether leak checking is enabled. The following +// functions should only be called if it returns true. +func leakCheckEnabled() bool { + return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking +} + +// Register adds obj to the live object map. +func Register(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + if _, ok := liveObjects[obj]; ok { + panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj)) + } + liveObjects[obj] = struct{}{} + liveObjectsMu.Unlock() + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "registered") + } + } +} + +// Unregister removes obj from the live object map. +func Unregister(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + if _, ok := liveObjects[obj]; !ok { + panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) + } + delete(liveObjects, obj) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "unregistered") + } + } +} + +// LogIncRef logs a reference increment. +func LogIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("IncRef to %d", refs)) + } +} + +// LogTryIncRef logs a successful TryIncRef call. +func LogTryIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs)) + } +} + +// LogDecRef logs a reference decrement. +func LogDecRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("DecRef to %d", refs)) + } +} + +// logEvent logs a message for the given reference-counted object. 
+// +// obj.LogRefs() should be checked before calling logEvent, in order to avoid +// calling any text processing needed to evaluate msg. +func logEvent(obj CheckedObject, msg string) { + log.Infof("[%s %p] %s:", obj.RefType(), obj, msg) + log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack())) +} + +// DoLeakCheck iterates through the live object map and logs a message for each +// object. It is called once no reference-counted objects should be reachable +// anymore, at which point anything left in the map is considered a leak. +func DoLeakCheck() { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + log.Warningf("Leak checking detected %d leaked objects:", leaked) + for obj := range liveObjects { + log.Warningf(obj.LeakMessage()) + } + } + } +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index d9b552896..8f50b4ee6 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -21,20 +21,24 @@ package refs_template import ( "fmt" - "runtime" "sync/atomic" - "gvisor.dev/gvisor/pkg/log" - refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" ) +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const enableLogging = false + // T is the type of the reference counted object. It is only used to customize // debug output when leak checking. type T interface{} -// ownerType is used to customize logging. Note that we use a pointer to T so -// that we do not copy the entire object when passed as a format parameter. -var ownerType *T +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var obj *T // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. @@ -42,11 +46,6 @@ var ownerType *T // Note that the number of references is actually refCount + 1 so that a default // zero-value Refs object contains one reference. // -// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in -// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. -// This will allow us to add stack trace information to the leak messages -// without growing the size of Refs. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -59,24 +58,24 @@ type Refs struct { refCount int64 } -func (r *Refs) finalize() { - var note string - switch refs_vfs1.GetLeakMode() { - case refs_vfs1.NoLeakChecking: - return - case refs_vfs1.UninitializedLeakChecking: - note = "(Leak checker uninitialized): " - } - if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) - } +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *Refs) RefType() string { + return fmt.Sprintf("%T", obj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *Refs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) } -// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. +// LogRefs implements refsvfs2.CheckedObject.LogRefs. 
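The refsvfs2 package above replaces finalizer-based leak detection with an explicit live-object map: objects Register themselves when created, Unregister when destroyed, and DoLeakCheck reports whatever is still registered at shutdown. A hedged sketch of a type wired into that API; the thing type is invented for illustration, only the refsvfs2 calls and CheckedObject methods come from this change, and all of them are no-ops unless the refs leak-checking mode has been enabled elsewhere:

```go
// Illustrative sketch of a reference-counted type wired into the refsvfs2
// leak checker added above. The thing type is made up; the refsvfs2 calls
// (Register, Unregister, LogIncRef, LogDecRef, DoLeakCheck) and the
// CheckedObject methods are the ones introduced in this change.
package main

import (
	"fmt"
	"sync/atomic"

	"gvisor.dev/gvisor/pkg/refsvfs2"
)

type thing struct {
	refs int64 // real references == refs + 1, matching the refs template
}

// RefType implements refsvfs2.CheckedObject.RefType.
func (t *thing) RefType() string { return "thing" }

// LeakMessage implements refsvfs2.CheckedObject.LeakMessage.
func (t *thing) LeakMessage() string {
	return fmt.Sprintf("[thing %p] reference count of %d instead of 0", t, atomic.LoadInt64(&t.refs)+1)
}

// LogRefs implements refsvfs2.CheckedObject.LogRefs.
func (t *thing) LogRefs() bool { return false }

func newThing() *thing {
	t := &thing{}
	refsvfs2.Register(t) // tracked in the live-object map while leak checking is on
	return t
}

func (t *thing) IncRef() {
	refsvfs2.LogIncRef(t, atomic.AddInt64(&t.refs, 1)+1)
}

func (t *thing) DecRef() {
	v := atomic.AddInt64(&t.refs, -1)
	refsvfs2.LogDecRef(t, v+1)
	if v == -1 {
		refsvfs2.Unregister(t) // destroyed: drop it from the live-object map
	}
}

func main() {
	t := newThing()
	t.IncRef()
	t.DecRef()
	// Forgetting this final DecRef would leave t registered, and DoLeakCheck
	// would log its LeakMessage at shutdown.
	t.DecRef()
	refsvfs2.DoLeakCheck()
}
```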
+func (r *Refs) LogRefs() bool { + return enableLogging +} + +// EnableLeakCheck enables reference leak checking on r. func (r *Refs) EnableLeakCheck() { - if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*Refs).finalize) - } + refsvfs2.Register(r) } // ReadRefs returns the current number of references. The returned count is @@ -90,8 +89,10 @@ func (r *Refs) ReadRefs() int64 { // //go:nosplit func (r *Refs) IncRef() { - if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + v := atomic.AddInt64(&r.refCount, 1) + refsvfs2.LogIncRef(r, v+1) + if v <= 0 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) } } @@ -104,15 +105,15 @@ func (r *Refs) IncRef() { //go:nosplit func (r *Refs) TryIncRef() bool { const speculativeRef = 1 << 32 - v := atomic.AddInt64(&r.refCount, speculativeRef) - if int32(v) < 0 { + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 { // This object has already been freed. atomic.AddInt64(&r.refCount, -speculativeRef) return false } // Turn into a real reference. - atomic.AddInt64(&r.refCount, -speculativeRef+1) + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + refsvfs2.LogTryIncRef(r, v+1) return true } @@ -129,14 +130,23 @@ func (r *Refs) TryIncRef() bool { // //go:nosplit func (r *Refs) DecRef(destroy func()) { - switch v := atomic.AddInt64(&r.refCount, -1); { + v := atomic.AddInt64(&r.refCount, -1) + refsvfs2.LogDecRef(r, v+1) + switch { case v < -1: - panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) case v == -1: + refsvfs2.Unregister(r) // Call the destructor. if destroy != nil { destroy() } } } + +func (r *Refs) afterLoad() { + if r.ReadRefs() > 0 { + r.EnableLeakCheck() + } +} diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index 41feeffe3..d800f2c85 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { s.Kernel.Kill(kernel.ExitStatus{}) }, } - return saveOpts.Save(s.Kernel, s.Watchdog) + return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog) } diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..ff5d49fbd 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -39,6 +39,8 @@ const ( ) // tunDevice implements vfs.Device for /dev/net/tun. +// +// +stateify savable type tunDevice struct{} // Open implements vfs.Device.Open. @@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt } // tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +// +// +stateify savable type tunFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 1390a9a7f..4468f5dd2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() { f.mappings = make(map[uint64]mapping) } +// IsInited returns true if f.Init() has been called. This is used when +// restoring a checkpoint that contains a HostFileMapper that may or may not +// have been initialized. 
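The TryIncRef body in the refs template above uses a speculative reference: the stored count is (real references - 1), the temporary increment lives in the upper 32 bits of the int64, and int32(v) isolates the lower 32 bits to detect an already-destroyed object. A standalone demonstration of the same arithmetic with a plain atomic counter; this copies the pattern shown above rather than gVisor's generated code:

```go
// Standalone demonstration of the speculative-reference pattern used by
// TryIncRef above: the speculative increment sits in the upper 32 bits, and
// int32(v) reads the lower 32 bits to see whether the object was freed.
package main

import (
	"fmt"
	"sync/atomic"
)

const speculativeRef = 1 << 32

func tryIncRef(refCount *int64) bool {
	if v := atomic.AddInt64(refCount, speculativeRef); int32(v) < 0 {
		// Lower 32 bits were already negative: the object was destroyed, so
		// back out the speculative reference and fail.
		atomic.AddInt64(refCount, -speculativeRef)
		return false
	}
	// Turn the speculative reference into a real one.
	atomic.AddInt64(refCount, -speculativeRef+1)
	return true
}

func main() {
	live := int64(0)  // one real reference (stored count is refs - 1)
	dead := int64(-1) // already destroyed

	ok := tryIncRef(&live)
	fmt.Println(ok, live) // true 1: now two real references
	ok = tryIncRef(&dead)
	fmt.Println(ok, dead) // false -1: left untouched
}
```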
+func (f *HostFileMapper) IsInited() bool { + return f.refs != nil +} + // NewHostFileMapper returns an initialized HostFileMapper allocated on the // heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 3c66dc3c2..6b3627813 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // maxFilenameLen is the maximum length of a filename. This is dictated by 9P's @@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, } // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Wrap the fileOps with our Fifo. iops := &fifo{ diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index e555672ad..52061175f 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { } // GetFile implements fs.InodeOperations.GetFile. -func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true - return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil + return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil } // +stateify savable diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 22d658acf..450044c9c 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "gid_map": newGIDMap(t, msrc), "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), + "mem": newMem(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "net": newNetDir(t, msrc), @@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(t, d, msrc, fs.SpecialDirectory, t) } +// memData implements fs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memData struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// memDataFile implements fs.FileOperations for /proc/[pid]/mem. +// +// +stateify savable +type memDataFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + inode := &memData{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, inode, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. 
+func (m *memData) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, m.t, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(m.t); err != nil { + return nil, err + } + // Enable random access reads + flags.Pread = true + return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + mm, err := getTaskMM(m.t) + if err != nil { + return 0, nil + } + defer mm.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + // mapsData implements seqfile.SeqSource for /proc/[pid]/maps. // // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index fc0498f17..d6c65301c 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Continue. seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil @@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // Write to that memory as usual. seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 998b697ca..cf4ed5de0 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -336,7 +336,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Build pipe InodeOperations. 
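
For context on the /proc/[pid]/mem file introduced above: reads are positioned by virtual address, and access to another task is gated by a ptrace check (the code above calls kernel.ContextCanTrace before handing out the file). The same read semantics can be observed on a Linux host by reading the calling process's own memory; a minimal sketch, assuming a 64-bit Linux machine:

package main

import (
	"fmt"
	"os"
	"runtime"
	"unsafe"
)

func main() {
	data := []byte("hello from memory")

	// /proc/self/mem is addressed by virtual address, so ReadAt performs a
	// pread at the location of data's backing array.
	f, err := os.Open("/proc/self/mem")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	buf := make([]byte, len(data))
	if _, err := f.ReadAt(buf, int64(uintptr(unsafe.Pointer(&data[0])))); err != nil {
		panic(err)
	}
	runtime.KeepAlive(data)
	fmt.Printf("%s\n", buf) // prints: hello from memory
}
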
iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 84baaac66..6af3c3781 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "root_inode_refs.go", package = "devpts", prefix = "rootInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "rootInode", }, @@ -33,6 +33,7 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d5c5aaa8c..346cca558 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir } fstype.initOnce.Do(func() { - fs, root, err := fstype.newFilesystem(vfsObj, creds) + fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds) if err != nil { fstype.initErr = err return @@ -93,7 +93,7 @@ type filesystem struct { // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { +func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -108,19 +108,19 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds root := &rootInode{ replicas: make(map[uint32]*replicaInode), } - root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) root.EnableLeakCheck() var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) // Construct the pts master inode and dentry. Linux always uses inode // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx. master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) // Add the master as a child of the root. links := root.OrderedChildren.Populate(map[string]kernfs.Inode{ @@ -170,7 +170,7 @@ type rootInode struct { var _ kernfs.Inode = (*rootInode)(nil) // allocateTerminal creates a new Terminal and installs a pts node for it. -func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) { +func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) { i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { @@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). 
- replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) i.replicas[idx] = replica return t, nil @@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro } // IterDirents implements kernfs.Inode.IterDirents. -func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() + i.InodeAttrs.TouchAtime(ctx, mnt) ids := make([]int, 0, len(i.replicas)) for id := range i.replicas { ids = append(ids, int(id)) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e6b0e81cf..ae95fdd08 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -100,10 +100,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index fda30fb93..e91fa26a4 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - t, err := mi.root.allocateTerminal(rp.Credentials()) + t, err := mi.root.allocateTerminal(ctx, rp.Credentials()) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index 01bbee5ad..e49a04c1b 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -4,7 +4,10 @@ licenses(["notice"]) go_library( name = "devtmpfs", - srcs = ["devtmpfs.go"], + srcs = [ + "devtmpfs.go", + "save_restore.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go new file mode 100644 index 000000000..28832d850 --- /dev/null +++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devtmpfs + +// afterLoad is invoked by stateify. +func (fst *FilesystemType) afterLoad() { + if fst.fs != nil { + // Ensure that we don't create another filesystem. 
+ fst.initOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 1c27ad700..5b29f2358 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -43,7 +43,7 @@ type EventFileDescription struct { // queue is used to notify interested parties when the event object // becomes readable or writable. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // mu protects the fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 045d7ab08..2158b1bbc 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "inode_refs.go", package = "fuse", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -49,6 +49,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 5986133e9..95c475a65 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F readPayload.MarshalUnsafe(outBuf[outHdrLen:]) outIOseq := usermem.BytesIOSequence(outBuf) - n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) if err != nil { t.Fatalf("Write failed :%v", err) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index e39df21c6..6de416da0 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newRootInode(creds, fsopts.rootMode) + root := fs.newRoot(ctx, creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } @@ -284,21 +284,21 @@ type inode struct { link string } -func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { i := &inode{fs: fs, nodeID: 1} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() var d kernfs.Dentry - d.Init(&fs.Filesystem, i) + d.InitRoot(&fs.Filesystem, i) return &d } -func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { +func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { i := &inode{fs: fs, nodeID: nodeID} creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} - i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -424,7 +424,7 @@ func (i *inode) Keep() bool { } // IterDirents implements kernfs.Inode.IterDirents. 
-func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } @@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { return nil, syserror.EIO } - child := i.fs.newInode(out.NodeID, out.Attr) + child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil } @@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return linux.FUSEAttr{}, err @@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 625d1547f..2d396e84c 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u // May need to update the signature. i := fd.inode() - // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount()) // Reached EOF. 
if sizeRead < size { @@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, Flags: fd.statusFlags(), } + inode := fd.inode() var written uint32 // This loop is intended for fragmented write where the bytes to write is @@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) if err != nil { return 0, err } @@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, break } } + inode.InodeAttrs.TouchCMtime(ctx) return written, nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index ad0afc41b..4c3e9acf8 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "save_restore.go", "socket.go", "special_file.go", "symlink.go", @@ -53,6 +54,7 @@ go_library( "//pkg/log", "//pkg/p9", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", @@ -70,6 +72,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 18c884b59..ce1b2a390 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -16,16 +16,17 @@ package gofer import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { child := &dentry{ refs: 1, // held by d fs: d.fs, - ino: d.fs.nextSyntheticIno(), + ino: d.fs.nextIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -100,6 +101,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { hostFD: -1, nlink: uint32(2), } + refsvfs2.Register(child) switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. @@ -235,7 +237,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: uint64(inoFromPath(p9d.QID.Path)), + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 94d96261b..bbb01148b 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -30,12 +30,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. 
func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current syncable dentries and special files. + // Snapshot current syncable dentries and special file FDs. fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { @@ -53,22 +52,28 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error - // Sync regular files. + // Sync syncable dentries. for _, d := range ds { - err := d.syncCachedFile(ctx) + err := d.syncCachedFile(ctx, true /* forFilesystemSync */) d.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } } } // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - err := sffd.Sync(ctx) + err := sffd.sync(ctx, true /* forFilesystemSync */) sffd.vfsfd.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } } } @@ -229,7 +234,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir return nil, err } if child != nil { - if !file.isNil() && inoFromPath(qid.Path) == child.ino { + if !file.isNil() && qid.Path == child.qidPath { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) child.updateFromP9AttrsLocked(attrMask, &attr) @@ -256,7 +261,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // treat their invalidation as deletion. child.setDeleted() parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() parent.dirents = nil } *ds = appendDentry(*ds, child) @@ -366,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if len(name) > maxFilenameLen { return syserror.ENAMETOOLONG } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.isDeleted() { return syserror.ENOENT } @@ -383,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if createInSyntheticDir == nil { return syserror.EPERM } @@ -402,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil && child.isSynthetic() { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // The existence of a non-synthetic dentry at name would be inconclusive // because the file it represents may have been deleted from the remote // filesystem, so we would need to make an RPC to revalidate the dentry. @@ -422,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. 
if err := createInRemoteDir(parent, name, &ds); err != nil { @@ -625,7 +636,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.setDeleted() if child.isSynthetic() { parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() } ds = appendDentry(ds, child) } @@ -836,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode: opts.Mode, kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) return nil } @@ -1355,7 +1366,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replaced.setDeleted() if replaced.isSynthetic() { newParent.syntheticChildren-- - replaced.decRefLocked() + replaced.decRefNoCaching() } ds = appendDentry(ds, replaced) } @@ -1364,7 +1375,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { - oldParent.decRefLocked() + oldParent.decRefNoCaching() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { @@ -1512,7 +1523,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath d.IncRef() return &endpoint{ dentry: d, - file: d.file.file, path: opts.Addr, }, nil } @@ -1591,7 +1601,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } - -func (fs *filesystem) nextSyntheticIno() inodeNumber { - return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index f1dad1b08..6f82ce61b 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -26,6 +26,9 @@ // *** "memmap.Mappable locks taken by Translate" below this point // dentry.handleMu // dentry.dataMu +// filesystem.inoMu +// specialFileFD.mu +// specialFileFD.bufMu // // Locking dentry.dirMu in multiple dentries requires that either ancestor // dentries are locked before descendant dentries, or that filesystem.renameMu @@ -36,7 +39,6 @@ import ( "fmt" "strconv" "strings" - "sync" "sync/atomic" "syscall" @@ -44,6 +46,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -53,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -81,7 +86,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + client *p9.Client `state:"nosave"` // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -89,6 +94,9 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. 
devMinor uint32 + // root is the root dentry. root is immutable. + root *dentry + // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -103,39 +111,35 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. + // cachedDentriesLen is the number of dentries in cachedDentries. These fields + // are protected by renameMu. cachedDentries dentryList cachedDentriesLen uint64 - // syncableDentries contains all dentries in this filesystem for which - // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. - // These fields are protected by syncMu. + // syncableDentries contains all non-synthetic dentries. specialFileFDs + // contains all open specialFileFDs. These fields are protected by syncMu. syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} - // syntheticSeq stores a counter to used to generate unique inodeNumber for - // synthetic dentries. - syntheticSeq uint64 -} + // inoByQIDPath maps previously-observed QID.Paths to inode numbers + // assigned to those paths. inoByQIDPath is not preserved across + // checkpoint/restore because QIDs may be reused between different gofer + // processes, so QIDs may be repeated for different files across + // checkpoint/restore. inoByQIDPath is protected by inoMu. + inoMu sync.Mutex `state:"nosave"` + inoByQIDPath map[uint64]uint64 `state:"nosave"` -// inodeNumber represents inode number reported in Dirent.Ino. For regular -// dentries, it comes from QID.Path from the 9P server. Synthetic dentries -// have have their inodeNumber generated sequentially, with the MSB reserved to -// prevent conflicts with regular dentries. -// -// +stateify savable -type inodeNumber uint64 + // lastIno is the last inode number assigned to a file. lastIno is accessed + // using atomic memory operations. + lastIno uint64 -// Reserve MSB for synthetic mounts. -const syntheticInoMask = uint64(1) << 63 + // savedDentryRW records open read/write handles during save/restore. + savedDentryRW map[*dentry]savedDentryRW -func inoFromPath(path uint64) inodeNumber { - if path&syntheticInoMask != 0 { - log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) - } - return inodeNumber(path &^ syntheticInoMask) + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // +stateify savable @@ -149,8 +153,7 @@ type filesystemOptions struct { msize uint32 version string - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. + // maxCachedDentries is the maximum size of filesystem.cachedDentries. maxCachedDentries uint64 // If forcePageCache is true, host FDs may not be used for application @@ -247,6 +250,10 @@ const ( // // +stateify savable type InternalFilesystemOptions struct { + // If UniqueID is non-empty, it is an opaque string used to reassociate the + // filesystem with a new server FD during restoration from checkpoint. + UniqueID string + // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. 
This is necessary for deployments in // which servers can handle only a single client and report failure if that @@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mopts := vfs.GenericParseMountOptions(opts.Data) var fsopts filesystemOptions - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) + fd, err := getFDFromMountOptionsMap(ctx, mopts) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL + return nil, nil, err } - fsopts.fd = rfd + fsopts.fd = fd // Get the attach name. fsopts.aname = "/" @@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // If !ok, iopts being the zero value is correct. - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) + // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } + fs := &filesystem{ + mfp: mfp, + opts: fsopts, + iopts: iopts, + clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + syncableDentries: make(map[*dentry]struct{}), + specialFileFDs: make(map[*specialFileFD]struct{}), + inoByQIDPath: make(map[uint64]uint64), + } + fs.vfsfs.Init(vfsObj, &fstype, fs) - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() + // Connect to the server. + if err := fs.dial(ctx); err != nil { return nil, nil, err } - // Ownership of conn has been transferred to client. // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fsopts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } - // Construct the filesystem object. 
- devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - attachFile.close(ctx) - client.Close() - return nil, nil, err - } - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - iopts: iopts, - client: client, - clock: ktime.RealtimeClockFromContext(ctx), - devMinor: devMinor, - syncableDentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, &fstype, fs) - // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { @@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. root.refs = 2 + fs.root = root return &fs.vfsfs, &root.vfsd, nil } +func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { + // Check that the transport is "fd". + trans, ok := mopts["trans"] + if !ok || trans != "fd" { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + return -1, syserror.EINVAL + } + delete(mopts, "trans") + + // Check that read and write FDs are provided and identical. + rfdstr, ok := mopts["rfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "rfdno") + rfd, err := strconv.Atoi(rfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + return -1, syserror.EINVAL + } + wfdstr, ok := mopts["wfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "wfdno") + wfd, err := strconv.Atoi(wfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + return -1, syserror.EINVAL + } + if rfd != wfd { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) + return -1, syserror.EINVAL + } + return rfd, nil +} + +// Preconditions: fs.client == nil. +func (fs *filesystem) dial(ctx context.Context) error { + // Establish a connection with the server. + conn, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return err + } + + // Perform version negotiation with the server. + ctx.UninterruptibleSleepStart(false) + client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + conn.Close() + return err + } + // Ownership of conn has been transferred to client. + + fs.client = client + return nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { - mf := fs.mfp.MemoryFile() + atomic.StoreInt32(&fs.released, 1) + mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. 
- if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) { // fs. fs.syncMu.Unlock() + // If leak checking is enabled, release all outstanding references in the + // filesystem. We deliberately avoid doing this outside of leak checking; we + // have released all external resources above rather than relying on dentry + // destructors. + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + fs.renameMu.Lock() + fs.root.releaseSyntheticRecursiveLocked(ctx) + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent it from + // being cached/evicted. + fs.root.DecRef(ctx) + } + if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() @@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements +// the reference count on every synthetic dentry. Synthetic dentries have one +// reference for existence that should be dropped during filesystem.Release. +// +// Precondition: d.fs.renameMu is locked. +func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { + if d.isSynthetic() { + d.decRefNoCaching() + d.checkCachingLocked(ctx) + } + if d.isDir() { + var children []*dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + if child != nil { + child.releaseSyntheticRecursiveLocked(ctx) + } + } + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -574,12 +635,15 @@ type dentry struct { // filesystem.renameMu. name string + // qidPath is the p9.QID.Path for this file. qidPath is immutable. + qidPath uint64 + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + file p9file `state:"nosave"` // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -623,12 +687,12 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex `state:"nosave"` - ino inodeNumber // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + metadataMu sync.Mutex `state:"nosave"` + ino uint64 // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. 
atime int64 mtime int64 @@ -679,9 +743,9 @@ type dentry struct { // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - hostFD int32 + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + hostFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, + qidPath: qid.Path, file: file, - ino: inoFromPath(qid.Path), + ino: fs.inoFromQIDPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -795,13 +860,28 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d.nlink = uint32(attr.NLink) } d.vfsd.Init(d) - + refsvfs2.Register(d) fs.syncMu.Lock() fs.syncableDentries[d] = struct{}{} fs.syncMu.Unlock() return d, nil } +func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + if ino, ok := fs.inoByQIDPath[qidPath]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByQIDPath[qidPath] = ino + return ino +} + +func (fs *filesystem) nextIno() uint64 { + return atomic.AddUint64(&fs.lastIno, 1) +} + func (d *dentry) isSynthetic() bool { return d.file.isNil() } @@ -853,7 +933,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } -// Preconditions: !d.isSynthetic() +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { // Use d.readFile or d.writeFile, which represent 9P fids that have been // opened, in preference to d.file, which represents a 9P fid that has not. @@ -916,10 +996,10 @@ func (d *dentry) statTo(stat *linux.Statx) { // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime)) + stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = d.fs.devMinor } @@ -967,10 +1047,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // Use client clocks for timestamps. now = d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { - stat.Atime = statxTimestampFromDentry(now) + stat.Atime = linux.NsecToStatxTimestamp(now) } if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { - stat.Mtime = statxTimestampFromDentry(now) + stat.Mtime = linux.NsecToStatxTimestamp(now) } } @@ -1029,11 +1109,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. 
if stat.Mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } if stat.Mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) @@ -1139,17 +1219,19 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkCachingLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -1157,22 +1239,41 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + if d.decRefNoCaching() == 0 { d.fs.renameMu.Lock() d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") } } -// decRefLocked decrements d's reference count without calling +// decRefNoCaching decrements d's reference count without calling // d.checkCachingLocked, even if d's reference count reaches 0; callers are // responsible for ensuring that d.checkCachingLocked will be called later. -func (d *dentry) decRefLocked() { - if refs := atomic.AddInt64(&d.refs, -1); refs < 0 { - panic("gofer.dentry.decRefLocked() called without holding a reference") +func (d *dentry) decRefNoCaching() int64 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r < 0 { + panic("gofer.dentry.decRefNoCaching() called without holding a reference") } + return r +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "gofer.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -1223,6 +1324,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + return + } if refs > 0 { if d.cached { d.fs.cachedDentries.Remove(d) @@ -1231,10 +1336,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } return } - if refs == -1 { - // Dentry has already been destroyed. - return - } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. 
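
The RefType, LeakMessage, and LogRefs methods above are what refsvfs2 consumes: a live object registers itself, unregisters when it reaches the destroyed state, and anything still registered when leak checking runs is reported via LeakMessage. The sketch below is an illustrative stand-in for such a registry, not the real refsvfs2 internals; only the three-method interface mirrors what the diff adds.

package leakcheck

import (
	"log"
	"sync"
)

// CheckedObject is the contract a leak-checked object satisfies.
type CheckedObject interface {
	RefType() string     // human-readable type name
	LeakMessage() string // message to emit if the object leaks
	LogRefs() bool       // whether to trace individual ref count changes
}

var (
	liveMu   sync.Mutex
	liveObjs = map[CheckedObject]struct{}{}
)

// Register records obj as live; Unregister removes it once its reference
// count reaches the destroyed state.
func Register(obj CheckedObject) {
	liveMu.Lock()
	liveObjs[obj] = struct{}{}
	liveMu.Unlock()
}

func Unregister(obj CheckedObject) {
	liveMu.Lock()
	delete(liveObjs, obj)
	liveMu.Unlock()
}

// LogIncRef traces a reference count change for objects that opt in
// (LogDecRef would be analogous).
func LogIncRef(obj CheckedObject, refs int64) {
	if obj.LogRefs() {
		log.Printf("%s %p: ref count is now %d", obj.RefType(), obj, refs)
	}
}

// DoLeakCheck reports anything still registered, e.g. at sandbox exit.
func DoLeakCheck() {
	liveMu.Lock()
	defer liveMu.Unlock()
	for obj := range liveObjs {
		log.Printf("leak detected: %s", obj.LeakMessage())
	}
}
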
if d.vfsd.IsDead() { @@ -1257,6 +1358,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.watches.Size() > 0 { return } + + if atomic.LoadInt32(&d.fs.released) != 0 { + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) + } + // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1269,33 +1380,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentriesLen++ d.cached = true if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) - d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() - } - victim.destroyLocked(ctx) - } + d.fs.evictCachedDentryLocked(ctx) // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. } } +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * fs.cachedDentriesLen != 0. +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. + } + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } +} + // destroyLocked destroys the dentry. // // Preconditions: @@ -1373,13 +1499,10 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. 
- if d.parent != nil { - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkCachingLocked(ctx) - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") - } + if d.parent != nil && d.parent.decRefNoCaching() == 0 { + d.parent.checkCachingLocked(ctx) } + refsvfs2.Unregister(d) } func (d *dentry) isDeleted() bool { @@ -1623,6 +1746,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { return nil } +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() + if h.isOpen() { + // Write back dirty pages to the remote file. + d.dataMu.Lock() + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } + } + if err := d.syncRemoteFileLocked(ctx); err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (d is a regular file and was opened for writing). + if d.isRegularFile() && h.isOpen() { + return err + } + ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1650,7 +1800,7 @@ type fileDescription struct { vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + lockLogging sync.Once `state:"nosave"` } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index bfe75dfe4..76f08e252 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -26,12 +26,13 @@ import ( func TestDestroyIdempotent(t *testing.T) { ctx := contexttest.Context(t) fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - syncableDentries: make(map[*dentry]struct{}), + mfp: pgalloc.MemoryFileProviderFromContext(ctx), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. maxCachedDentries: 0, }, + syncableDentries: make(map[*dentry]struct{}), + inoByQIDPath: make(map[uint64]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 7294de7d6..c7bf10007 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { if ok { return nil } - if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { - return err + if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil { + // Another application thread may have opened this pipe for + // writing, succeeded because we previously opened the pipe for + // reading, and subsequently interrupted us for checkpointing (e.g. + // this occurs in mknod tests under cooperative save/restore). In + // this case, our open has to succeed for the checkpoint to include + // a readable FD for the pipe, which is in turn necessary to + // restore the other thread's writable FD for the same pipe + // (otherwise it will get ENXIO). So we have to check + // nonblockingPipeHasWriter() once last time. 
+ ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + return sleepErr } } } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f8b19bae7..dc8a890cb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -18,7 +18,6 @@ import ( "fmt" "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx) -} - -func (d *dentry) syncCachedFile(ctx context.Context) error { - d.handleMu.RLock() - defer d.handleMu.RUnlock() - - if h := d.writeHandleLocked(); h.isOpen() { - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) - d.dataMu.Unlock() - if err != nil { - return err - } - } - return d.syncRemoteFileLocked(ctx) + return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -913,7 +897,7 @@ type dentryPlatformFile struct { hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go new file mode 100644 index 000000000..17849dcc0 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -0,0 +1,329 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type saveRestoreContextID int + +const ( + // CtxRestoreServerFDMap is a Context.Value key for a map[string]int + // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID) + // to host FDs. + CtxRestoreServerFDMap saveRestoreContextID = iota +) + +// +stateify savable +type savedDentryRW struct { + read bool + write bool +} + +// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave. 
+func (fs *filesystem) PrepareSave(ctx context.Context) error { + if len(fs.iopts.UniqueID) == 0 { + return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved") + } + + // Purge cached dentries, which may not be reopenable after restore due to + // permission changes. + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // Buffer pipe data so that it's available for reading after restore. (This + // is a legacy VFS1 feature.) + fs.syncMu.Lock() + for sffd := range fs.specialFileFDs { + if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() { + if err := sffd.savePipeData(ctx); err != nil { + fs.syncMu.Unlock() + return err + } + } + } + fs.syncMu.Unlock() + + // Flush local state to the remote filesystem. + if err := fs.Sync(ctx); err != nil { + return err + } + + fs.savedDentryRW = make(map[*dentry]savedDentryRW) + return fs.root.prepareSaveRecursive(ctx) +} + +// Preconditions: +// * fd represents a pipe. +// * fd is readable. +func (fd *specialFileFD) savePipeData(ctx context.Context) error { + fd.bufMu.Lock() + defer fd.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) + if n != 0 { + fd.buf = append(fd.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syserror.EAGAIN { + break + } + return err + } + } + if len(fd.buf) != 0 { + atomic.StoreUint32(&fd.haveBuf, 1) + } + return nil +} + +func (d *dentry) prepareSaveRecursive(ctx context.Context) error { + if d.isRegularFile() && !d.cachedMetadataAuthoritative() { + // Get updated metadata for d in case we need to perform metadata + // validation during restore. + if err := d.updateFromGetattr(ctx); err != nil { + return err + } + } + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } + } + d.dirMu.Lock() + defer d.dirMu.Unlock() + for _, child := range d.children { + if child != nil { + if err := child.prepareSaveRecursive(ctx); err != nil { + return err + } + } + } + return nil +} + +// beforeSave is invoked by stateify. +func (d *dentry) beforeSave() { + if d.vfsd.IsDead() { + panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d))) + } +} + +// afterLoad is invoked by stateify. +func (d *dentry) afterLoad() { + d.hostFD = -1 + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (d *dentryPlatformFile) afterLoad() { + if d.hostFileMapper.IsInited() { + // Ensure that we don't call d.hostFileMapper.Init() again. + d.hostFileMapperInitOnce.Do(func() {}) + } +} + +// afterLoad is invoked by stateify. +func (fd *specialFileFD) afterLoad() { + fd.handle.fd = -1 +} + +// CompleteRestore implements +// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore. +func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error { + fdmapv := ctx.Value(CtxRestoreServerFDMap) + if fdmapv == nil { + return fmt.Errorf("no server FD map available") + } + fdmap := fdmapv.(map[string]int) + fd, ok := fdmap[fs.iopts.UniqueID] + if !ok { + return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) + } + fs.opts.fd = fd + if err := fs.dial(ctx); err != nil { + return err + } + fs.inoByQIDPath = make(map[uint64]uint64) + + // Restore the filesystem root. 
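
savePipeData above drains whatever is already readable from a pipe's handle so the bytes can be replayed from fd.buf after restore. The same idea against a nonblocking host pipe FD looks roughly like the standalone sketch below (illustrative helper, not the sentry's handle type):

package main

import (
	"fmt"
	"syscall"
)

// drainPipe reads everything currently buffered in a nonblocking pipe read
// end, stopping at EAGAIN (nothing more pending) or EOF (all writers gone).
func drainPipe(fd int) ([]byte, error) {
	var saved []byte
	buf := make([]byte, 4096)
	for {
		n, err := syscall.Read(fd, buf)
		if n > 0 {
			saved = append(saved, buf[:n]...)
		}
		if err == syscall.EAGAIN {
			return saved, nil
		}
		if err != nil {
			return saved, err
		}
		if n == 0 {
			return saved, nil
		}
	}
}

func main() {
	var fds [2]int
	if err := syscall.Pipe2(fds[:], syscall.O_NONBLOCK); err != nil {
		panic(err)
	}
	if _, err := syscall.Write(fds[1], []byte("buffered before checkpoint")); err != nil {
		panic(err)
	}
	data, err := drainPipe(fds[0])
	fmt.Printf("%q %v\n", data, err)
}
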
+ ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } + + // Restore remaining dentries. + if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil { + return err + } + + // Re-open handles for specialFileFDs. Unlike the initial open + // (dentry.openSpecialFile()), pipes are always opened without blocking; + // non-readable pipe FDs are opened last to ensure that they don't get + // ENXIO if another specialFileFD represents the read end of the same pipe. + // This is consistent with VFS1. + haveWriteOnlyPipes := false + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + haveWriteOnlyPipes = true + continue + } + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + if haveWriteOnlyPipes { + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + } + } + + // Discard state only required during restore. + fs.savedDentryRW = nil + + return nil +} + +func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error { + d.file = file + + // Gofers do not preserve QID across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking QIDs. + // + // - We need to associate the new QID.Path with the existing d.ino. + d.qidPath = qid.Path + d.fs.inoMu.Lock() + d.fs.inoByQIDPath[qid.Path] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if !attrMask.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != attr.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if !attrMask.MTime { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromP9AttrsLocked(attrMask, attr) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + +// Preconditions: d is not synthetic. 
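+// d.file has already been restored, since it is needed to walk to and restore
+// d's children.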
+func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + for _, child := range d.children { + if child == nil { + continue + } + if _, ok := d.fs.syncableDentries[child]; !ok { + // child is synthetic. + continue + } + if err := child.restoreRecursive(ctx, opts); err != nil { + return err + } + } + return nil +} + +// Preconditions: d is not synthetic (but note that since this function +// restores d.file, d.file.isNil() is always true at this point, so this can +// only be detected by checking filesystem.syncableDentries). d.parent has been +// restored. +func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } + return d.restoreDescendantsRecursive(ctx, opts) +} + +func (fd *specialFileFD) completeRestore(ctx context.Context) error { + d := fd.dentry() + h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + if err != nil { + return err + } + fd.handle = h + + ftype := d.fileType() + fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0 + if fd.haveQueue { + if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 326b940a7..a21199eac 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -42,9 +42,6 @@ type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry - // file is the p9 file that contains a single unopened fid. - file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - // path is the sentry path where this endpoint is bound. path string } @@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect } func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.file.Connect(flags) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { return nil, syserr.ErrConnectionRefused } @@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) return nil, serr } return c, nil diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 71581736c..625400c0b 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -15,7 +15,6 @@ package gofer import ( - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -40,7 +40,7 @@ type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. 
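+ // handle is not saved; it is reopened by completeRestore() after a
+ // restore.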
- handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + handle handle `state:"nosave"` // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -54,12 +54,20 @@ type specialFileFD struct { // haveQueue is true if this file description represents a file for which // queue may send I/O readiness events. haveQueue is immutable. - haveQueue bool + haveQueue bool `state:"nosave"` queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex `state:"nosave"` off int64 + + // If haveBuf is non-zero, this FD represents a pipe, and buf contains data + // read from the pipe from previous calls to specialFileFD.savePipeData(). + // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic + // memory operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { @@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } return nil, err } + d.fs.syncMu.Lock() + d.fs.specialFileFDs[fd] = struct{}{} + d.fs.syncMu.Unlock() return fd, nil } @@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, syserror.EOPNOTSUPP } - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } + + bufN := int64(0) + if atomic.LoadUint32(&fd.haveBuf) != 0 { + var err error + fd.bufMu.Lock() + if len(fd.buf) != 0 { + var n int + n, err = dst.CopyOut(ctx, fd.buf) + dst = dst.DropFirst(n) + fd.buf = fd.buf[n:] + if len(fd.buf) == 0 { + atomic.StoreUint32(&fd.haveBuf, 0) + fd.buf = nil + } + bufN = int64(n) + if offset >= 0 { + offset += bufN + } + } + fd.bufMu.Unlock() + if err != nil { + return bufN, err + } + } + + // Going through dst.CopyOutFrom() would hold MM locks around file + // operations of unknown duration. For regularFileFD, doing so is necessary + // to support mmap due to lock ordering; MM locks precede dentry.dataMu. + // That doesn't hold here since specialFileFD doesn't client-cache data. + // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } if n == 0 { - return 0, err + return bufN, err } if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr + return bufN + int64(cp), cperr } - return int64(n), err + return bufN + int64(n), err } // Read implements vfs.FileDescriptionImpl.Read. @@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() - // If the regular file fd was opened with O_APPEND, make sure the file size - // is updated. There is a possible race here if size is modified externally - // after metadata cache is updated. 
- if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, offset, err + if fd.isRegularFile { + // If the regular file fd was opened with O_APPEND, make sure the file + // size is updated. There is a possible race here if size is modified + // externally after metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - } - if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if fd.handle.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(fd.handle.fd)) - ctx.UninterruptibleSleepFinish(false) - return err + return fd.sync(ctx, false /* forFilesystemSync */) +} + +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { + err := func() error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) + }() + if err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (fd represents a regular file that was opened for writing). + if fd.isRegularFile && fd.vfsfd.IsWritable() { + return err + } + ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err) } - return fd.handle.file.fsync(ctx) + return nil } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7e825caae..9cbe805b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,7 +17,6 @@ package gofer import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - // Preconditions: d.cachedMetadataAuthoritative() == true. 
func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 56bcf9bdb..4ae9d6d5e 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "inode_refs.go", package = "host", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "connected_endpoint_refs.go", package = "host", prefix = "ConnectedEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ConnectedEndpoint", }, @@ -33,7 +33,7 @@ go_library( "host.go", "inode_refs.go", "ioctl_unsafe.go", - "mmap.go", + "save_restore.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -51,6 +51,7 @@ go_library( "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index 0135e4428..13ef48cb5 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription { } // Create the file backed by hostFD. - file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */) + file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{}) if err != nil { ctx.Warningf("Error creating file from host FD: %v", err) break diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 698e913fe..39b902a3e 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -19,6 +19,7 @@ package host import ( "fmt" "math" + "sync/atomic" "syscall" "golang.org/x/sys/unix" @@ -40,34 +41,97 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + kernfs.InodeNoStatFS + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.CachedMappable + kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. + + locks vfs.FileLocks + + // When the reference count reaches zero, the host fd is closed. + inodeRefs + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // ino is an inode number unique within this filesystem. + // + // This field is initialized at creation time and is immutable. + ino uint64 + + // ftype is the file's type (a linux.S_IFMT mask). + // + // This field is initialized at creation time and is immutable. + ftype uint16 + + // mayBlock is true if hostFD is non-blocking, and operations on it may + // return EAGAIN or EWOULDBLOCK instead of blocking. + // + // This field is initialized at creation time and is immutable. + mayBlock bool + + // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file + // offsets are meaningful iff seekable is true. 
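+ // Pipes, sockets, and some character devices are not seekable.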
+ //
+ // This field is initialized at creation time and is immutable.
+ seekable bool
+
+ // isTTY is true if this file represents a TTY.
+ //
+ // This field is initialized at creation time and is immutable.
+ isTTY bool
+
+ // savable is true if hostFD may be saved/restored by its numeric value.
+ //
+ // This field is initialized at creation time and is immutable.
+ savable bool
+
+ // Event queue for blocking operations.
+ queue waiter.Queue
+
+ // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data
+ // read from the pipe from previous calls to inode.beforeSave(). haveBuf
+ // and buf are protected by bufMu. haveBuf is accessed using atomic memory
+ // operations.
+ bufMu sync.Mutex `state:"nosave"`
+ haveBuf uint32
+ buf []byte
+}
+
+func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) {
+ // Determine if hostFD is seekable.
 _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR)
 seekable := err != syserror.ESPIPE
+ // We expect regular files to be seekable, as this is required for them to
+ // be memory-mappable.
+ if !seekable && fileType == syscall.S_IFREG {
+ ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD)
+ return nil, syserror.ESPIPE
+ }

 i := &inode{
- hostFD: hostFD,
- ino: fs.NextIno(),
- isTTY: isTTY,
- wouldBlock: wouldBlock(uint32(fileType)),
- seekable: seekable,
- // NOTE(b/38213152): Technically, some obscure char devices can be memory
- // mapped, but we only allow regular files.
- canMap: fileType == linux.S_IFREG,
- }
- i.pf.inode = i
+ hostFD: hostFD,
+ ino: fs.NextIno(),
+ ftype: uint16(fileType),
+ mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR,
+ seekable: seekable,
+ isTTY: isTTY,
+ savable: savable,
+ }
+ i.CachedMappable.Init(hostFD)
 i.EnableLeakCheck()

- // Non-seekable files can't be memory mapped, assert this.
- if !i.seekable && i.canMap {
- panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped")
- }
-
- // If the hostFD would block, we must set it to non-blocking and handle
- // blocking behavior in the sentry.
- if i.wouldBlock {
+ // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and
+ // handle blocking behavior in the sentry.
+ if i.mayBlock {
 if err := syscall.SetNonblock(i.hostFD, true); err != nil {
 return nil, err
 }
@@ -80,6 +144,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (

 // NewFDOptions contains options to NewFD.
 type NewFDOptions struct {
+ // If Savable is true, the host file descriptor may be saved/restored by
+ // numeric value; the sandbox API requires a corresponding host FD with the
+ // same numeric value to be provided at time of restore.
+ Savable bool
+
 // If IsTTY is true, the file descriptor is a TTY.
 IsTTY bool

@@ -114,7 +183,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
 }

 d := &kernfs.Dentry{}
- i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
+ i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY)
 if err != nil {
 return nil, err
 }
@@ -132,7 +201,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions)
 // ImportFD sets up and returns a vfs.FileDescription from a donated fd.
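+// Unlike NewFD with default options, FDs imported through ImportFD are marked
+// savable, so the sandbox must provide a host FD with the same numeric value
+// at restore time.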
func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) { return NewFD(ctx, mnt, hostFD, &NewFDOptions{ - IsTTY: isTTY, + Savable: true, + IsTTY: isTTY, }) } @@ -191,68 +261,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } -// inode implements kernfs.Inode. -// -// +stateify savable -type inode struct { - kernfs.InodeNoStatFS - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. - - locks vfs.FileLocks - - // When the reference count reaches zero, the host fd is closed. - inodeRefs - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // seekable is false if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - seekable bool - - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. - // - // This field is initialized at creation time and is immutable. - wouldBlock bool - - // Event queue for blocking operations. - queue waiter.Queue - - // canMap specifies whether we allow the file to be memory mapped. - // - // This field is initialized at creation time and is immutable. - canMap bool - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // If canMap is true, mappings tracks mappings of hostFD into - // memmap.MappingSpaces. - mappings memmap.MappingSet - - // pf implements platform.File for mappings of hostFD. - pf inodePlatformFile -} - // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t @@ -422,14 +430,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre oldpgend, _ := usermem.PageRoundUp(oldSize) newpgend, _ := usermem.PageRoundUp(s.Size) if oldpgend != newpgend { - i.mapsMu.Lock() - i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - i.mapsMu.Unlock() + i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend}) } } } @@ -448,7 +449,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre // DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { i.inodeRefs.DecRef(func() { - if i.wouldBlock { + if i.mayBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } if err := unix.Close(i.hostFD); err != nil { @@ -567,6 +568,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin // PRead implements vfs.FileDescriptionImpl.PRead. 
func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { return 0, syserror.ESPIPE @@ -577,19 +585,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { + bufN, err := i.readFromBuf(ctx, &dst) + if err != nil { + return bufN, err + } n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) + total := bufN + n if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read // rather than retrying until complete. - if n != 0 { + if total != 0 { err = nil } else { err = syserror.ErrWouldBlock } } - return n, err + return total, err } f.offsetMu.Lock() @@ -599,13 +619,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } -func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // Check that flags are supported. - // - // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. - if flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP +func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) { + if atomic.LoadUint32(&i.haveBuf) == 0 { + return 0, nil + } + i.bufMu.Lock() + defer i.bufMu.Unlock() + if len(i.buf) == 0 { + return 0, nil } + n, err := dst.CopyOut(ctx, i.buf) + *dst = dst.DropFirst(n) + i.buf = i.buf[n:] + if len(i.buf) == 0 { + atomic.StoreUint32(&i.haveBuf, 0) + i.buf = nil + } + return int64(n), err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := dst.CopyOutFrom(ctx, reader) hostfd.PutReadWriterAt(reader) @@ -735,31 +768,37 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i } // Sync implements vfs.FileDescriptionImpl.Sync. -func (f *fileDescription) Sync(context.Context) error { +func (f *fileDescription) Sync(ctx context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.inode.canMap { + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + if f.inode.ftype != syscall.S_IFREG { return syserror.ENODEV } i := f.inode - i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) + i.CachedMappable.InitFileMapperOnce() return vfs.GenericConfigureMMap(&f.vfsfd, i, opts) } // EventRegister implements waiter.Waitable.EventRegister. 
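+// The host FD's fdnotifier registration is only updated when the FD may block
+// (see inode.mayBlock); other host FDs are never registered with fdnotifier.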
func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { f.inode.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // EventUnregister implements waiter.Waitable.EventUnregister. func (f *fileDescription) EventUnregister(e *waiter.Entry) { f.inode.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // Readiness uses the poll() syscall to check the status of the underlying FD. diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go new file mode 100644 index 000000000..8800652a9 --- /dev/null +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -0,0 +1,70 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/usermem" +) + +// beforeSave is invoked by stateify. +func (i *inode) beforeSave() { + if !i.savable { + panic("host.inode is not savable") + } + if i.ftype == syscall.S_IFIFO { + // If this pipe FD is readable, drain it so that bytes in the pipe can + // be read after restore. (This is a legacy VFS1 feature.) We don't + // know if the pipe FD is readable, so just try reading and tolerate + // EBADF from the read. + i.bufMu.Lock() + defer i.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) + if n != 0 { + i.buf = append(i.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF { + break + } + panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err)) + } + } + if len(i.buf) != 0 { + atomic.StoreUint32(&i.haveBuf, 1) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inode) afterLoad() { + if i.mayBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err)) + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err)) + } + } +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 412bdb2eb..b2f43a119 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} } -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. 
-func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 858cc24ce..6dbc7e34d 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "kernfs", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Dentry", + "Linker": "*Dentry", + }, +) + +go_template_instance( name = "fstree", out = "fstree.go", package = "kernfs", @@ -27,22 +39,11 @@ go_template_instance( ) go_template_instance( - name = "dentry_refs", - out = "dentry_refs.go", - package = "kernfs", - prefix = "Dentry", - template = "//pkg/refs_vfs2:refs_template", - types = { - "T": "Dentry", - }, -) - -go_template_instance( name = "static_directory_refs", out = "static_directory_refs.go", package = "kernfs", prefix = "StaticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "StaticDirectory", }, @@ -53,7 +54,7 @@ go_template_instance( out = "dir_refs.go", package = "kernfs_test", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -64,7 +65,7 @@ go_template_instance( out = "readonly_dir_refs.go", package = "kernfs_test", prefix = "readonlyDir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "readonlyDir", }, @@ -75,7 +76,7 @@ go_template_instance( out = "synthetic_directory_refs.go", package = "kernfs", prefix = "syntheticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "syntheticDirectory", }, @@ -84,13 +85,15 @@ go_template_instance( go_library( name = "kernfs", srcs = [ - "dentry_refs.go", + "dentry_list.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", "fstree.go", "inode_impl_util.go", "kernfs.go", + "mmap_util.go", + "save_restore.go", "slot_list.go", "static_directory_refs.go", "symlink.go", @@ -104,8 +107,12 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/safemem", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", @@ -129,6 +136,7 @@ go_test( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index b929118b1..485504995 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -47,11 +47,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. 
-func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index abf1905d6..f8dae22f8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() } +func (fd *GenericDirectoryFD) dentry() *Dentry { + return fd.vfsfd.Dentry().Impl().(*Dentry) +} + func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return fd.dentry().inode } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds @@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { - vfsd := fd.vfsfd.VirtualDentry().Dentry() - parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode + parentInode := genericParentOrSelf(fd.dentry()).inode stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err @@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent var err error relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset) return err } @@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fd.filesystem(), creds, opts) + return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts) } // Allocate implements vfs.FileDescriptionImpl.Allocate. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 6426a55f6..e77523f22 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { - if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return "", err +func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST + if name == "." || name == ".." 
{ + return syserror.EEXIST } - if len(pc) > linux.NAME_MAX { - return "", syserror.ENAMETOOLONG + if len(name) > linux.NAME_MAX { + return syserror.ENAMETOOLONG } - if _, ok := parent.children[pc]; ok { - return "", syserror.EEXIST + if _, ok := parent.children[name]; ok { + return syserror.EEXIST } if parent.VFSDentry().IsDead() { - return "", syserror.ENOENT + return syserror.ENOENT } - return pc, nil + return nil } // checkDeleteLocked checks that the file represented by vfsd may be deleted. @@ -245,7 +244,41 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release(context.Context) { +func (fs *Filesystem) Release(ctx context.Context) { + root := fs.root + if root == nil { + return + } + fs.mu.Lock() + root.releaseKeptDentriesLocked(ctx) + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } + fs.mu.Unlock() + // Drop ref acquired in Dentry.InitRoot(). + root.DecRef(ctx) +} + +// releaseKeptDentriesLocked recursively drops all dentry references created by +// Lookup when Dentry.inode.Keep() is true. +// +// Precondition: Filesystem.mu is held. +func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) { + if d.inode.Keep() && d != d.fs.root { + d.decRefLocked(ctx) + } + + if d.isDir() { + var children []*Dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + child.releaseKeptDentriesLocked(ctx) + } + } } // Sync implements vfs.FilesystemImpl.Sync. @@ -318,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -360,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -373,7 +409,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } - childI = newSyntheticDirectory(rp.Credentials(), opts.Mode) + childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) } var child Dentry child.Init(fs, childI) @@ -396,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } @@ -517,9 +556,6 @@ afterTrailingSymlink: } var child Dentry child.Init(fs, childI) - // FIXME(gvisor.dev/issue/1193): Race between checking existence with - // fs.stepExistingLocked and parent.insertChild. If possible, we should hold - // dirMu from one to the other. 
parent.insertChild(pc, &child) // Open may block so we need to unlock fs.mu. IncRef child to prevent // its destruction while fs.mu is unlocked. @@ -626,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDir) - switch err { + pc := rp.Component() + switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err { case nil: // Ok, continue with rename as replacement. case syserror.EEXIST: @@ -791,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 122b10591..d83c17f83 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -21,9 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) } // IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } @@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // +stateify savable type InodeAttrs struct { - devMajor uint32 - devMinor uint32 - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 + blockSize uint32 + + // Timestamps, all nsecs from the Unix epoch. + atime int64 + mtime int64 + ctime int64 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) + atomic.StoreUint32(&a.blockSize, usermem.PageSize) + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.atime, now) + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) } // DevMajor returns the device major number. 
@@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// TouchAtime updates a.atime to the current time. +func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds()) + mnt.EndWrite() +} + +// TouchCMtime updates a.{c/m}time to the current time. The caller should +// synchronize calls to this so that ctime and mtime are updated to the same +// value. +func (a *InodeAttrs) TouchCMtime(ctx context.Context) { + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) +} + // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME stat.DevMajor = a.devMajor stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) @@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li stat.UID = atomic.LoadUint32(&a.uid) stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - + stat.Blksize = atomic.LoadUint32(&a.blockSize) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime)) return stat, nil } // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - return a.SetInodeStat(ctx, fs, creds, opts) -} - -// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. -// This function can be used by other kernfs-based filesystem implementation to -// sets the unexported attributes into InodeAttrs. -func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } @@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds // inode numbers are immutable after node creation. Setting the size is often // allowed by kernfs files but does not do anything. If some other behavior is // needed, the embedder should consider extending SetStat. - // - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. 
- if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { @@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds atomic.StoreUint32(&a.gid, stat.GID) } + now := ktime.NowFromContext(ctx).Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.atime, stat.Atime.ToNsec()) + } + if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec()) + } + return nil } @@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error } // IterDirents implements Inode.IterDirents. -func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { // All entries from OrderedChildren have already been handled in // GenericDirectoryFD.IterDirents. return offset, nil @@ -528,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e return o.Unlink(ctx, name, child) } -// +stateify savable -type renameAcrossDifferentImplementationsError struct{} - -func (renameAcrossDifferentImplementationsError) Error() string { - return "rename across inodes with different implementations" -} - // Rename implements Inode.Rename. // // Precondition: Rename may only be called across two directory inodes with @@ -545,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string { // // Postcondition: reference on any replaced dentry transferred to caller. func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error { + if !o.writable { + return syserror.EPERM + } + dst, ok := dstDir.(interface{}).(*OrderedChildren) if !ok { - return renameAcrossDifferentImplementationsError{} + return syserror.EXDEV } - if !o.writable || !dst.writable { + if !dst.writable { return syserror.EPERM } + // Note: There's a potential deadlock below if concurrent calls to Rename // refer to the same src and dst directories in reverse. We avoid any // ordering issues because the caller is required to serialize concurrent @@ -619,9 +657,9 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. 
-func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { +func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) inode.EnableLeakCheck() inode.OrderedChildren.Init(OrderedChildrenOptions{}) @@ -632,12 +670,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { +func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } s.fdOpts = fdOpts - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 606081e68..abb477c7d 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -107,6 +108,23 @@ type Filesystem struct { // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. nextInoMinusOne uint64 + + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by mu. + cachedDentries dentryList + cachedDentriesLen uint64 + + // MaxCachedDentries is the maximum size of cachedDentries. If not set, + // defaults to 0 and kernfs does not cache any dentries. This is immutable. + MaxCachedDentries uint64 + + // root is the root dentry of this filesystem. Note that root may be nil for + // filesystems on a disconnected mount without a root (e.g. pipefs, sockfs, + // hostfs). Filesystem holds an extra reference on root to prevent it from + // being destroyed prematurely. This is immutable. + root *Dentry } // deferDecRef defers dropping a dentry ref until the next call to @@ -165,7 +183,12 @@ const ( // +stateify savable type Dentry struct { vfsd vfs.Dentry - DentryRefs + + // refs is the reference count. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has already + // been destroyed. refs are allowed to go to 0 and increase again. refs is + // accessed using atomic memory operations. + refs int64 // fs is the owning filesystem. fs is immutable. fs *Filesystem @@ -177,6 +200,12 @@ type Dentry struct { parent *Dentry name string + // If cached is true, dentryEntry links dentry into + // Filesystem.cachedDentries. 
cached and dentryEntry are protected by + // Filesystem.mu. + cached bool + dentryEntry + // dirMu protects children and the names of child Dentries. // // Note that holding fs.mu for writing is not sufficient; @@ -188,6 +217,201 @@ type Dentry struct { inode Inode } +// IncRef implements vfs.DentryImpl.IncRef. +func (d *Dentry) IncRef() { + // d.refs may be 0 if d.fs.mu is locked, which serializes against + // d.cacheLocked(). + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *Dentry) TryIncRef() bool { + for { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.fs.mu.Lock() + d.cacheLocked(ctx) + d.fs.mu.Unlock() + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +func (d *Dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.cacheLocked(ctx) + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +// cacheLocked should be called after d's reference count becomes 0. The ref +// count check may happen before acquiring d.fs.mu so there might be a race +// condition where the ref count is increased again by the time the caller +// acquires d.fs.mu. This race is handled. +// Only reachable dentries are added to the cache. However, a dentry might +// become unreachable *while* it is in the cache due to invalidation. +// +// Preconditions: d.fs.mu must be locked for writing. +func (d *Dentry) cacheLocked(ctx context.Context) { + // Dentries with a non-zero reference count must be retained. (The only way + // to obtain a reference on a dentry with zero references is via path + // resolution, which requires d.fs.mu, so if d.refs is zero then it will + // remain zero while we hold d.fs.mu for writing.) + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d)) + } + if refs > 0 { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + return + } + // If the dentry is deleted and invalidated or has no parent, then it is no + // longer reachable by path resolution and should be dropped immediately + // because it has zero references. + // Note that a dentry may not always have a parent; for example magic links + // as described in Inode.Getlink. + if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil { + if !isDead { + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) + } + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + d.destroyLocked(ctx) + return + } + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. 
+ d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries { + return + } + d.fs.evictCachedDentryLocked(ctx) + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.opts.maxCachedDentries, so we don't loop. +} + +// Preconditions: +// * fs.mu must be locked for writing. +// * fs.cachedDentriesLen != 0. +func (fs *Filesystem) evictCachedDentryLocked(ctx context.Context) { + // Evict the least recently used dentry because cache size is greater than + // max cache size (configured on mount). + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // after it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if !victim.vfsd.IsDead() { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry()) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.MaxCachedDentries, so we don't loop. +} + +// destroyLocked destroys the dentry. +// +// Preconditions: +// * d.fs.mu must be locked for writing. +// * d.refs == 0. +// * d should have been removed from d.parent.children, i.e. d is not reachable +// by path traversal. +// * d.vfsd.IsDead() is true. +func (d *Dentry) destroyLocked(ctx context.Context) { + refs := atomic.LoadInt64(&d.refs) + switch refs { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + + if d.parent != nil { + d.parent.decRefLocked(ctx) + } + + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *Dentry) RefType() string { + return "kernfs.Dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *Dentry) LeakMessage() string { + return fmt.Sprintf("[kernfs.Dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *Dentry) LogRefs() bool { + return false +} + +// InitRoot initializes this dentry as the root of the filesystem. +// +// Precondition: Caller must hold a reference on inode. +// +// Postcondition: Caller's reference on inode is transferred to the dentry. +func (d *Dentry) InitRoot(fs *Filesystem, inode Inode) { + d.Init(fs, inode) + fs.root = d + // Hold an extra reference on the root dentry. It is held by fs to prevent the + // root from being "cached" and subsequently evicted. + d.IncRef() +} + // Init initializes this dentry. // // Precondition: Caller must hold a reference on inode. 
@@ -197,6 +421,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { d.vfsd.Init(d) d.fs = fs d.inode = inode + atomic.StoreInt64(&d.refs, 1) ftype := inode.Mode().FileType() if ftype == linux.ModeDirectory { d.flags |= dflagsIsDir @@ -204,7 +429,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } - d.EnableLeakCheck() + refsvfs2.Register(d) } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -222,32 +447,6 @@ func (d *Dentry) isSymlink() bool { return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 } -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(ctx context.Context) { - decRefParent := false - d.fs.mu.Lock() - d.DentryRefs.DecRef(func() { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - // We will DecRef d.parent once all locks are dropped. - decRefParent = true - d.parent.dirMu.Lock() - // Remove d from parent.children. It might already have been - // removed due to invalidation. - if _, ok := d.parent.children[d.name]; ok { - delete(d.parent.children, d.name) - d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) - } - d.parent.dirMu.Unlock() - } - }) - d.fs.mu.Unlock() - if decRefParent { - d.parent.DecRef(ctx) // IncRef from Dentry.insertChild. - } -} - // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // Although Linux technically supports inotify on pseudo filesystems (inotify @@ -267,7 +466,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // this dentry. This does not update the directory inode, so calling this on its // own isn't sufficient to insert a child into a directory. // -// Precondition: d must represent a directory inode. +// Preconditions: +// * d must represent a directory inode. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChild(name string, child *Dentry) { d.dirMu.Lock() d.insertChildLocked(name, child) @@ -280,6 +481,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) { // Preconditions: // * d must represent a directory inode. // * d.dirMu must be locked. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d)) @@ -436,7 +638,7 @@ type inodeDirectory interface { // the inode is a directory. // // The child returned by Lookup will be hashed into the VFS dentry tree, - // atleast for the duration of the current FS operation. + // at least for the duration of the current FS operation. // // Lookup must return the child with an extra reference whose ownership is // transferred to the dentry that is created to point to that inode. If @@ -454,7 +656,7 @@ type inodeDirectory interface { // inside the entries returned by this IterDirents invocation. In other words, // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as // the return value, while 'relOffset' is the place to start iteration. 
- IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) + IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 82fa19c03..2418eec44 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file." // RootDentryFn is a generator function for creating the root dentry of a test // filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode +type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode // newTestSystem sets up a minimal environment for running a test, including an // instance of a test filesystem. Tests can control the contents of the @@ -72,10 +72,10 @@ type file struct { content string } -func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode { +func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) return f } @@ -105,9 +105,9 @@ type readonlyDir struct { locks vfs.FileLocks } -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &readonlyDir{} - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) dir.EnableLeakCheck() dir.IncLinks(dir.OrderedChildren.Populate(contents)) @@ -142,10 +142,10 @@ type dir struct { fs *filesystem } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) dir.EnableLeakCheck() @@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) { func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) + dir := d.fs.newDir(ctx, creds, opts.Mode, nil) if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) d.IncLinks(1) return dir, nil } func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") + f := d.fs.newFile(ctx, creds, "") if err := 
d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) return f, nil } @@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {} func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { fs := &filesystem{} fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(creds, fs) + root := fst.rootFn(ctx, creds, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, root) return fs.VFSFilesystem(), d.VFSDentry(), nil @@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst // -------------------- Remainder of the file are test cases -------------------- func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -228,9 +230,9 @@ func TestBasic(t *testing.T) { } func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) { } func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) { } func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, nil) + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, nil) }) defer sys.Destroy() @@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) { } func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return 
fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(creds, 0755, nil), + "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir3": fs.newDir(ctx, creds, 0755, nil), }), - "file1": fs.newFile(creds, staticFileContent), + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go index b51a17bed..bd6a134b4 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package host +package kernfs import ( "gvisor.dev/gvisor/pkg/context" @@ -26,11 +26,14 @@ import ( // inodePlatformFile implements memmap.File. It exists solely because inode // cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // -// inodePlatformFile should only be used if inode.canMap is true. -// // +stateify savable type inodePlatformFile struct { - *inode + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + // inodePlatformFile does not own hostFD and hence should not close it. + hostFD int // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex `state:"nosave"` @@ -43,12 +46,12 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + fileMapperInitOnce sync.Once `state:"nosave"` } +var _ memmap.File = (*inodePlatformFile)(nil) + // IncRef implements memmap.File.IncRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) @@ -56,8 +59,6 @@ func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { } // DecRef implements memmap.File.DecRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) @@ -65,8 +66,6 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } @@ -76,10 +75,32 @@ func (i *inodePlatformFile) FD() int { return i.hostFD } -// AddMapping implements memmap.Mappable.AddMapping. +// CachedMappable implements memmap.Mappable. This utility can be embedded in a +// kernfs.Inode that represents a host file to make the inode mappable. +// CachedMappable caches the mappings of the host file. CachedMappable must be +// initialized (via Init) with a hostFD before use. // -// Precondition: i.inode.canMap must be true. 
-func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +// +stateify savable +type CachedMappable struct { + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // mappings tracks mappings of hostFD into memmap.MappingSpaces. + mappings memmap.MappingSet + + // pf implements memmap.File for mappings backed by a host fd. + pf inodePlatformFile +} + +var _ memmap.Mappable = (*CachedMappable)(nil) + +// Init initializes i.pf. This must be called before using CachedMappable. +func (i *CachedMappable) Init(hostFD int) { + i.pf.hostFD = hostFD +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { i.mapsMu.Lock() mapped := i.mappings.AddMapping(ms, ar, offset, writable) for _, r := range mapped { @@ -90,9 +111,7 @@ func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar userm } // RemoveMapping implements memmap.Mappable.RemoveMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { i.mapsMu.Lock() unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) for _, r := range unmapped { @@ -102,16 +121,12 @@ func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar us } // CopyMapping implements memmap.Mappable.CopyMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { return i.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { mr := optional return []memmap.Translation{ { @@ -124,10 +139,26 @@ func (i *inode) Translate(ctx context.Context, required, optional memmap.Mappabl } // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) InvalidateUnsavable(ctx context.Context) error { +func (i *CachedMappable) InvalidateUnsavable(ctx context.Context) error { // We expect the same host fd across save/restore, so all translations // should be valid. return nil } + +// InvalidateRange invalidates the passed range on i.mappings. +func (i *CachedMappable) InvalidateRange(r memmap.MappableRange) { + i.mapsMu.Lock() + i.mappings.Invalidate(r, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + i.mapsMu.Unlock() +} + +// InitFileMapperOnce initializes the host file mapper. It ensures that the +// file mapper is initialized just once. 
+func (i *CachedMappable) InitFileMapperOnce() { + i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) +} diff --git a/pkg/sentry/fsimpl/kernfs/save_restore.go b/pkg/sentry/fsimpl/kernfs/save_restore.go new file mode 100644 index 000000000..f78509eb7 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/save_restore.go @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// afterLoad is invoked by stateify. +func (d *Dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) >= 0 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (i *inodePlatformFile) afterLoad() { + if i.fileMapper.IsInited() { + // Ensure that we don't call i.fileMapper.Init() again. + i.fileMapperInitOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 934cc6c9e..a0736c0d6 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -38,16 +38,16 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { +func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { inode := &StaticSymlink{} - inode.Init(creds, devMajor, devMinor, ino, target) + inode.Init(ctx, creds, devMajor, devMinor, ino, target) return inode } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { +func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode.Readlink. 
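The afterLoad hook in the new kernfs save_restore.go above relies on a small sync.Once trick: if the saved state shows the lazy initializer already ran, the Once is consumed with a no-op so a later Do call does not run the real initializer again. A self-contained sketch of that pattern, with hypothetical names (lazyMapper, ensureInit, afterRestore):

package main

import (
	"fmt"
	"sync"
)

// lazyMapper mimics inodePlatformFile: mapper-style state is initialized
// lazily via a sync.Once that is not saved across checkpoint/restore.
type lazyMapper struct {
	initOnce sync.Once // `state:"nosave"` in the real type
	inited   bool      // stands in for fileMapper.IsInited()
}

func (m *lazyMapper) realInit() {
	m.inited = true
	fmt.Println("expensive init ran")
}

// ensureInit mirrors InitFileMapperOnce: callable any number of times, the
// initializer runs at most once.
func (m *lazyMapper) ensureInit() {
	m.initOnce.Do(m.realInit)
}

// afterRestore mirrors afterLoad: if the saved state says initialization
// already happened, burn the Once with an empty function so realInit is
// never re-run against already-initialized state.
func (m *lazyMapper) afterRestore() {
	if m.inited {
		m.initOnce.Do(func() {})
	}
}

func main() {
	m := &lazyMapper{inited: true} // pretend this was loaded from a checkpoint
	m.afterRestore()
	m.ensureInit() // prints nothing: realInit is skipped
	fmt.Println("done")
}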
diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index d0ed17b18..463d77d79 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -41,17 +41,17 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) -func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode { +func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { inode := &syntheticDirectory{} - inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) return inode } -func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) @@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs if !opts.ForSyntheticMountpoint { return nil, syserror.EPERM } - subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) if err := dir.OrderedChildren.Insert(name, subdirI); err != nil { subdirI.DecRef(ctx) return nil, err } + dir.TouchCMtime(ctx) return subdirI, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 1e11b0428..bf13bbbf4 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -23,6 +23,7 @@ go_library( "fstree.go", "overlay.go", "regular_file.go", + "save_restore.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -30,6 +31,8 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 4506642ca..469f3a33d 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er if dirent.Name == "." || dirent.Name == ".." { continue } - child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 78a01bbb7..bc07d72c0 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // * fs.renameMu must be locked. 
// * d.dirMu must be locked. // * !rp.Done(). -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, lookupLayerNone, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, lookupLayerNone, err } afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } rp.Advance() - return d.parent, nil + return d.parent, d.parent.topLookupLayer(), nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err + return nil, topLookupLayer, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx) if err != nil { - return nil, err + return nil, lookupLayerNone, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, topLookupLayer, err } goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) { if child, ok := parent.children[name]; ok { - return child, nil + return child, child.topLookupLayer(), nil } - child, err := fs.lookupLocked(ctx, parent, name) + child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name) if err != nil { - return nil, err + return nil, topLookupLayer, err } if parent.children == nil { parent.children = make(map[string]*dentry) @@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s parent.children[name] = child // child's refcount is initially 0, so it may be dropped after traversal. *ds = appendDentry(*ds, child) - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. 
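The lookupLocked changes below thread a lookupLayer result through lookup instead of a bare "exists on any layer" flag. As a rough mental model (deliberately simplified relative to the real code; layerEntry and resolve are hypothetical stand-ins), the layer-resolution rule is: walk layers top down, stop at a whiteout, let the topmost hit determine the result, and never merge a non-directory with lower layers:

package main

import "fmt"

// layerEntry describes what a single layer reports for one name.
type layerEntry struct {
	exists   bool
	whiteout bool
	isDir    bool
}

type lookupLayer int

const (
	lookupLayerNone lookupLayer = iota
	lookupLayerLower
	lookupLayerUpper
	lookupLayerUpperWhiteout
)

// resolve walks layers from top to bottom (index 0 is the upper layer) and
// returns the topmost layer that determines the name's fate.
func resolve(layers []layerEntry) lookupLayer {
	top := lookupLayerNone
	for i, l := range layers {
		isUpper := i == 0
		if !l.exists {
			continue
		}
		if l.whiteout {
			if isUpper {
				top = lookupLayerUpperWhiteout
			}
			break // layers below a whiteout are ignored
		}
		if top != lookupLayerNone && !l.isDir {
			break // an already-found directory hides lower non-directories
		}
		if top == lookupLayerNone {
			if isUpper {
				top = lookupLayerUpper
			} else {
				top = lookupLayerLower
			}
		}
		if !l.isDir {
			break // a non-directory is served from a single layer
		}
	}
	return top
}

func main() {
	// Upper layer has a whiteout for the name; the lower-layer file is hidden.
	fmt.Println(resolve([]layerEntry{
		{exists: true, whiteout: true},
		{exists: true},
	}) == lookupLayerUpperWhiteout) // true

	// Only a lower layer has the file.
	fmt.Println(resolve([]layerEntry{
		{},
		{exists: true, isDir: true},
	}) == lookupLayerLower) // true
}

Returning lookupLayerUpperWhiteout is what lets createAndOpenLocked (further below) unlink a known whiteout instead of speculatively unlinking on every O_CREAT open.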
-func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { +func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) { childPath := fspath.Parse(name) child := fs.newDentry() - existsOnAnyLayer := false + topLookupLayer := lookupLayerNone var lookupErr error vfsObj := fs.vfsfs.VirtualFilesystem() @@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) - if !existsOnAnyLayer { + if topLookupLayer == lookupLayerNone { // Mode, UID, GID, and (for non-directories) inode number come from // the topmost layer on which the file exists. mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if isWhiteout(&stat) { // This is a whiteout, so it "doesn't exist" on this layer, and // layers below this one are ignored. + if isUpper { + topLookupLayer = lookupLayerUpperWhiteout + } return false } isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR - if existsOnAnyLayer && !isDir { + if topLookupLayer != lookupLayerNone && !isDir { // Directories are not merged with non-directory files from lower // layers; instead, layers including and below the first // non-directory file are ignored. (This file must be a directory @@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } else { child.lowerVDs = append(child.lowerVDs, childVD) } - if !existsOnAnyLayer { - existsOnAnyLayer = true + if topLookupLayer == lookupLayerNone { + if isUpper { + topLookupLayer = lookupLayerUpper + } else { + topLookupLayer = lookupLayerLower + } child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID @@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if lookupErr != nil { child.destroyLocked(ctx) - return nil, lookupErr + return nil, topLookupLayer, lookupErr } - if !existsOnAnyLayer { + if !topLookupLayer.existsInOverlay() { child.destroyLocked(ctx) - return nil, syserror.ENOENT + return nil, topLookupLayer, syserror.ENOENT } // Device and inode numbers were copied from the topmost layer above; @@ -302,14 +310,20 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str child.devMinor = fs.dirDevMinor child.ino = fs.newDirIno() } else if !child.upperVD.Ok() { + childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor) + if err != nil { + ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) + child.destroyLocked(ctx) + return nil, topLookupLayer, err + } child.devMajor = linux.UNNAMED_MAJOR - child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()] + child.devMinor = childDevMinor } parent.IncRef() child.parent = parent child.name = name - return child, nil + return child, topLookupLayer, nil } // lookupLayerLocked is similar to lookupLocked, but only returns information @@ -408,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool { func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, 
true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -428,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, d := rp.Start().Impl().(*dentry) for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -463,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if name == "." || name == ".." { return syserror.EEXIST } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.vfsd.IsDead() { return syserror.ENOENT } @@ -489,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } + // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { @@ -791,9 +806,9 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { - fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds) + fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err } @@ -893,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * // Preconditions: // * parent.dirMu must be locked. // * parent does not already contain a child named rp.Component(). -func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { +func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { return nil, err @@ -918,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving Start: parent.upperVD, Path: fspath.Parse(childName), } - // We don't know if a whiteout exists on the upper layer; speculatively - // unlink it. - // - // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do - // know whether a whiteout exists. - var haveUpperWhiteout bool - switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err { - case nil: - haveUpperWhiteout = true - case syserror.ENOENT: - haveUpperWhiteout = false - default: - return nil, err + // Unlink the whiteout if it exists. + if haveUpperWhiteout { + if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil { + log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err) + return nil, err + } } // Create the file on the upper layer, and get an FD representing it. upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{ @@ -961,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Re-lookup to get a dentry representing the new file, which is needed for // the returned FD. 
- child, err := fs.getChildLocked(ctx, parent, childName, ds) + child, _, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) @@ -970,7 +978,10 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - // Finally construct the overlay FD. + // Finally construct the overlay FD. Below this point, we don't perform + // cleanup (the file was created successfully even if we can no longer open + // it for some reason). + parent.dirents = nil upperFlags := upperFD.StatusFlags() fd := ®ularFileFD{ copiedUp: true, @@ -981,8 +992,6 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving upperFDOpts := upperFD.Options() if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil { upperFD.DecRef(ctx) - // Don't bother with cleanup; the file was created successfully, we - // just can't open it anymore for some reason. return nil, err } parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */) @@ -1040,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } @@ -1072,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.vfsd.IsDead() { return syserror.ENOENT } - replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) - if err != nil { - return err - } var ( - replaced *dentry - replacedVFSD *vfs.Dentry - whiteouts map[string]bool + replaced *dentry + replacedVFSD *vfs.Dentry + replacedLayer lookupLayer + whiteouts map[string]bool ) - if replacedLayer.existsInOverlay() { - replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil { - return err - } + replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { + return err + } + if replaced != nil { replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { @@ -1289,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Unlike UnlinkAt, we need a dentry representing the child directory being // removed in order to verify that it's empty. - child, err := fs.getChildLocked(ctx, parent, name, &ds) + child, _, err := fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } @@ -1541,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if parentMode&linux.S_ISVTX != 0 { // If the parent's sticky bit is set, we need a child dentry to get // its owner. 
- child, err = fs.getChildLocked(ctx, parent, name, &ds) + child, _, err = fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 4c5de8d32..73130bc8d 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,7 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// filesystem.devMu // *** "memmap.Mappable locks" below this point // dentry.mapsMu // *** "memmap.Mappable locks taken by Translate" below this point @@ -33,12 +34,14 @@ package overlay import ( + "fmt" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -99,10 +102,15 @@ type filesystem struct { // is immutable. dirDevMinor uint32 - // lowerDevMinors maps lower layer filesystems to device minor numbers - // assigned to non-directory files originating from that filesystem. - // lowerDevMinors is immutable. - lowerDevMinors map[*vfs.Filesystem]uint32 + // lowerDevMinors maps device numbers from lower layer filesystems to + // device minor numbers assigned to non-directory files originating from + // that filesystem. (This remapping is necessary for lower layers because a + // file on a lower layer, and that same file on an overlay, are + // distinguishable because they will diverge after copy-up; this isn't true + // for non-directory files already on the upper layer.) lowerDevMinors is + // protected by devMu. + devMu sync.Mutex `state:"nosave"` + lowerDevMinors map[layerDevNumber]uint32 // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different @@ -114,78 +122,69 @@ type filesystem struct { lastDirIno uint64 } +// +stateify savable +type layerDevNumber struct { + major uint32 + minor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { mopts := vfs.GenericParseMountOptions(opts.Data) fsoptsRaw := opts.InternalData - fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) - if fsoptsRaw != nil && !haveFSOpts { + fsopts, ok := fsoptsRaw.(FilesystemOptions) + if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } - if haveFSOpts { - if len(fsopts.LowerRoots) == 0 { - ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + + if upperPathname, ok := mopts["upperdir"]; ok { + if fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") return nil, nil, syserror.EINVAL } - if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + delete(mopts, "upperdir") + // Linux overlayfs also requires a workdir when upperdir is + // specified; we don't, so silently ignore this option. 
+ delete(mopts, "workdir") + upperPath := fspath.Parse(upperPathname) + if !upperPath.Absolute { + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } - // We don't enforce a maximum number of lower layers when not - // configured by applications; the sandbox owner can have an overlay - // filesystem with any number of lower layers. - } else { - vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef(ctx) - upperPathname, ok := mopts["upperdir"] - if ok { - delete(mopts, "upperdir") - // Linux overlayfs also requires a workdir when upperdir is - // specified; we don't, so silently ignore this option. - delete(mopts, "workdir") - upperPath := fspath.Parse(upperPathname) - if !upperPath.Absolute { - ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL - } - upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ - Root: vfsroot, - Start: vfsroot, - Path: upperPath, - FollowFinalSymlink: true, - }, &vfs.GetDentryOptions{ - CheckSearchable: true, - }) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer upperRoot.DecRef(ctx) - privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer privateUpperRoot.DecRef(ctx) - fsopts.UpperRoot = privateUpperRoot + upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: upperPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + return nil, nil, err + } + privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) + upperRoot.DecRef(ctx) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + return nil, nil, err } - lowerPathnamesStr, ok := mopts["lowerdir"] - if !ok { - ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + defer privateUpperRoot.DecRef(ctx) + fsopts.UpperRoot = privateUpperRoot + } + + if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { + if len(fsopts.LowerRoots) != 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") - const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK - if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") - return nil, nil, syserror.EINVAL - } - if len(lowerPathnames) > maxLowerLayers { - ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) - return nil, nil, syserror.EINVAL - } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { @@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt 
ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) + lowerRoot.DecRef(ctx) if err != nil { ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err @@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } + if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } - // Allocate device numbers. + if len(fsopts.LowerRoots) == 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") + return nil, nil, syserror.EINVAL + } + if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") + return nil, nil, syserror.EINVAL + } + const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK + if len(fsopts.LowerRoots) > maxLowerLayers { + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) + return nil, nil, syserror.EINVAL + } + + // Allocate dirDevMinor. lowerDevMinors are allocated dynamically. dirDevMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } - lowerDevMinors := make(map[*vfs.Filesystem]uint32) - for _, lowerRoot := range fsopts.LowerRoots { - lowerFS := lowerRoot.Mount().Filesystem() - if _, ok := lowerDevMinors[lowerFS]; !ok { - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - vfsObj.PutAnonBlockDevMinor(dirDevMinor) - for _, lowerDevMinor := range lowerDevMinors { - vfsObj.PutAnonBlockDevMinor(lowerDevMinor) - } - return nil, nil, err - } - lowerDevMinors[lowerFS] = devMinor - } - } // Take extra references held by the filesystem. 
if fsopts.UpperRoot.Ok() { @@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt opts: fsopts, creds: creds.Fork(), dirDevMinor: dirDevMinor, - lowerDevMinors: lowerDevMinors, + lowerDevMinors: make(map[layerDevNumber]uint32), } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root.ino = fs.newDirIno() } else if !root.upperVD.Ok() { root.devMajor = linux.UNNAMED_MAJOR - root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()] + rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err) + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + root.devMinor = rootDevMinor root.ino = rootStat.Ino } else { root.devMajor = rootStat.DevMajor @@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 { return atomic.AddUint64(&fs.lastDirIno, 1) } +func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) { + fs.devMu.Lock() + defer fs.devMu.Unlock() + orig := layerDevNumber{layerMajor, layerMinor} + if minor, ok := fs.lowerDevMinors[orig]; ok { + return minor, nil + } + minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor() + if err != nil { + return 0, err + } + fs.lowerDevMinors[orig] = minor + return minor, nil +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -458,9 +479,9 @@ type dentry struct { // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is // accessed using atomic memory operations. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` lowerMappings memmap.MappingSet - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` wrappedMappable memmap.Mappable isMappable uint32 @@ -484,6 +505,7 @@ func (fs *filesystem) newDentry() *dentry { } d.lowerVDs = d.inlineLowerVDs[:0] d.vfsd.Init(d) + refsvfs2.Register(d) return d } @@ -491,17 +513,19 @@ func (fs *filesystem) newDentry() *dentry { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkDropLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -509,15 +533,27 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("overlay.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("overlay.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. 
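The getLowerDevMinor helper introduced above replaces the map keyed by *vfs.Filesystem with lazy, mutex-guarded memoization keyed by the lower layer's (major, minor) device number. A standalone sketch of that memoization (devMapper and allocMinor are hypothetical; allocMinor stands in for VirtualFilesystem.GetAnonBlockDevMinor):

package main

import (
	"fmt"
	"sync"
)

// devNumber is the key: the device number reported by a lower layer.
type devNumber struct {
	major, minor uint32
}

// devMapper memoizes an overlay-private minor number per lower-layer device.
type devMapper struct {
	mu     sync.Mutex
	next   uint32
	byOrig map[devNumber]uint32
}

// allocMinor is a placeholder for anonymous block device minor allocation.
func (m *devMapper) allocMinor() uint32 {
	m.next++
	return m.next
}

// lowerDevMinor returns a stable minor number for the given lower-layer
// device, allocating one on first use, as getLowerDevMinor does above.
func (m *devMapper) lowerDevMinor(major, minor uint32) uint32 {
	m.mu.Lock()
	defer m.mu.Unlock()
	orig := devNumber{major, minor}
	if v, ok := m.byOrig[orig]; ok {
		return v
	}
	v := m.allocMinor()
	m.byOrig[orig] = v
	return v
}

func main() {
	m := &devMapper{byOrig: make(map[devNumber]uint32)}
	fmt.Println(m.lowerDevMinor(8, 1)) // 1
	fmt.Println(m.lowerDevMinor(8, 2)) // 2
	fmt.Println(m.lowerDevMinor(8, 1)) // 1 again: memoized
}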
// @@ -577,12 +613,27 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.parent.dirMu.Unlock() // Drop the reference held by d on its parent without recursively // locking d.fs.renameMu. - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("overlay.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "overlay.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -645,6 +696,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry { return vd } +func (d *dentry) topLookupLayer() lookupLayer { + if d.upperVD.Ok() { + return lookupLayerUpper + } + return lookupLayerLower +} + func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go new file mode 100644 index 000000000..54809f16c --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package overlay + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e44b79b68..0ecb592cf 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -101,7 +101,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem) *inode { creds := auth.CredentialsFromContext(ctx) return &inode{ - pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize), ino: fs.Filesystem.NextIno(), uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2e086e34c..5196a2a80 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "fd_dir_inode_refs.go", package = "proc", prefix = "fdDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdDirInode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "fd_info_dir_inode_refs.go", package = "proc", prefix = "fdInfoDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdInfoDirInode", }, @@ -30,7 +30,7 @@ go_template_instance( out = "subtasks_inode_refs.go", package = "proc", prefix = "subtasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "subtasksInode", }, @@ -41,7 +41,7 @@ go_template_instance( out = "task_inode_refs.go", package = "proc", prefix = "taskInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "taskInode", }, @@ -52,7 +52,7 @@ go_template_instance( out = "tasks_inode_refs.go", package = "proc", prefix = "tasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tasksInode", }, @@ -82,6 +82,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index fd70a07de..8716d0a3c 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -17,6 +17,7 @@ package proc import ( "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -24,10 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "proc" +const ( + // Name is the default filesystem name. + Name = "proc" + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType is the factory class for procfs. 
// @@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if err != nil { return nil, nil, err } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + procfs := &filesystem{ devMinor: devMinor, } + procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -74,9 +92,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - inode := procfs.newTasksInode(k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, cgroups) var dentry kernfs.Dentry - dentry.Init(&procfs.Filesystem, inode) + dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil } @@ -94,11 +112,11 @@ type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { - inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) +func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) return inode } @@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } -func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { - return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ +func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { + return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) } diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index bad2fab4f..cb3c5e0fd 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) subInode.EnableLeakCheck() @@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements kernfs.inodeDirectory.IterDirents. 
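The proc GetFilesystem change above accepts a dentry_cache_limit mount option and falls back to defaultMaxCachedDentries when the option is absent. A minimal sketch of that option handling in isolation (parseDentryCacheLimit is a hypothetical helper, not part of the proc package):

package main

import (
	"fmt"
	"strconv"
)

const defaultMaxCachedDentries = uint64(1000)

// parseDentryCacheLimit consumes the option if present and rejects values
// that are not valid unsigned integers, as the GetFilesystem change does.
func parseDentryCacheLimit(mopts map[string]string) (uint64, error) {
	limit := defaultMaxCachedDentries
	if str, ok := mopts["dentry_cache_limit"]; ok {
		delete(mopts, "dentry_cache_limit")
		v, err := strconv.ParseUint(str, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid dentry_cache_limit=%q: %v", str, err)
		}
		limit = v
	}
	return limit, nil
}

func main() {
	opts := map[string]string{"dentry_cache_limit": "512"}
	limit, err := parseDentryCacheLimit(opts)
	fmt.Println(limit, err, len(opts)) // 512 <nil> 0
}

The option arrives through the mount data string (parsed by vfs.GenericParseMountOptions), so a mount data value such as dentry_cache_limit=512 selects the cache size.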
-func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { return offset, syserror.ENOENT diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index b63a4eca0..19011b010 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(task, fs.NextIno(), 0400), "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), "net": fs.newTaskNetDir(task), @@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil) func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } @@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 2c80ac5c2..d268b44be 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDirectory.IterDirents. 
-func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Lookup implements kernfs.inodeDirectory.Lookup. @@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern task: task, fd: fd, } - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements Inode.IterDirents. -func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Open implements kernfs.Inode.Open. 
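Note on the dentry_cache_limit mount option introduced for procfs above (and for sysfs further below in this change): both GetFilesystem implementations follow the same pattern of reading the option out of the generic mount data, falling back to a default, and rejecting values that fail strconv.ParseUint. The following is a distilled, standalone sketch of that pattern only; it uses plain Go maps and errors in place of the sentry's vfs.GenericParseMountOptions and syserror packages, so the helper names here are illustrative, not gVisor API.

package main

import (
	"fmt"
	"strconv"
	"strings"
)

const defaultMaxCachedDentries = uint64(1000)

// parseMountOptions models comma-separated "key=value" mount data, the shape
// that vfs.GenericParseMountOptions hands back to GetFilesystem.
func parseMountOptions(data string) map[string]string {
	mopts := make(map[string]string)
	for _, opt := range strings.Split(data, ",") {
		if opt == "" {
			continue
		}
		if kv := strings.SplitN(opt, "=", 2); len(kv) == 2 {
			mopts[kv[0]] = kv[1]
			continue
		}
		mopts[opt] = ""
	}
	return mopts
}

// maxCachedDentries pulls dentry_cache_limit out of the parsed options,
// keeping the default when it is absent and rejecting unparseable values,
// which is where the real code returns EINVAL.
func maxCachedDentries(mopts map[string]string) (uint64, error) {
	limit := defaultMaxCachedDentries
	if str, ok := mopts["dentry_cache_limit"]; ok {
		delete(mopts, "dentry_cache_limit")
		parsed, err := strconv.ParseUint(str, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid dentry cache limit: dentry_cache_limit=%s", str)
		}
		limit = parsed
	}
	return limit, nil
}

func main() {
	for _, data := range []string{"", "dentry_cache_limit=2000", "dentry_cache_limit=bogus"} {
		limit, err := maxCachedDentries(parseMountOptions(data))
		fmt.Printf("data=%q -> limit=%d err=%v\n", data, limit, err)
	}
}

Mount data such as "dentry_cache_limit=2000" raises the cache ceiling; anything unparseable is rejected, matching the EINVAL behavior of the GetFilesystem hunks above.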
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 79f8b7e9f..ba71d0fde 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -249,7 +250,7 @@ type commInode struct { func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in return int64(srclen), nil } +var _ kernfs.Inode = (*memInode)(nil) + +// memInode implements kernfs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + task *kernel.Task + locks vfs.FileLocks +} + +func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &memInode{task: task} + inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) +} + +// Open implements kernfs.Inode.Open. +func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, f.task, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(f.task); err != nil { + return nil, err + } + fd := &memFD{} + if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +var _ vfs.FileDescriptionImpl = (*memFD)(nil) + +// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem. +// +// +stateify savable +type memFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *memInode + + // mu guards the fields below. + mu sync.Mutex `state:"nosave"` + offset int64 +} + +// Init initializes memFD. 
+func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error { + fd.LockFD.Init(&inode.locks) + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return err + } + fd.inode = inode + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + case linux.SEEK_CUR: + offset += fd.offset + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.offset = offset + return offset, nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + m, err := getMMIncRef(fd.inode.task) + if err != nil { + return 0, nil + } + defer m.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.offset, opts) + fd.offset += n + fd.mu.Unlock() + return n, err +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *memFD) Release(context.Context) {} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. 
// // +stateify savable @@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil) func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil) func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode @@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // Create a synthetic inode to represent the namespace. fs := mnt.Filesystem().Impl().(*filesystem) + nsInode := &namespaceInode{} + nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444) dentry := &kernfs.Dentry{} - dentry.Init(&fs.Filesystem, &namespaceInode{}) + dentry.Init(&fs.Filesystem, nsInode) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1. mnt.IncRef() @@ -897,11 +1056,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 3425e8698..5a9ee111f 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. 
- "arp": fs.newInode(root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(root, 0444, &netStatData{}), - "packet": fs.newInode(root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(task, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(task, root, 0444, &netStatData{}), + "packet": fs.newInode(task, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(root, 0444, newStaticFile(ptype)), - "route": fs.newInode(root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(task, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6)) } } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3259c3732..b81ea14bf 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -62,19 +62,19 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ - "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": fs.newInode(root, 0444, &filesystemsData{}), - "loadavg": fs.newInode(root, 0444, &loadavgData{}), - "sys": fs.newSysDir(root, k), - "meminfo": fs.newInode(root, 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), - "stat": fs.newInode(root, 0444, &statData{}), - "uptime": fs.newInode(root, 0444, &uptimeData{}), - "version": fs.newInode(root, 0444, &versionData{}), + "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newInode(ctx, root, 0444, 
&filesystemsData{}), + "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), + "sys": fs.newSysDir(ctx, root, k), + "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newInode(ctx, root, 0444, &statData{}), + "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), + "version": fs.newInode(ctx, root, 0444, &versionData{}), } inode := &tasksInode{ @@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace fs: fs, cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) @@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err // If it failed to parse, check if it's one of the special handled files. switch name { case selfName: - return i.newSelfSymlink(root), nil + return i.newSelfSymlink(ctx, root), nil case threadSelfName: - return i.newThreadSelfSymlink(root), nil + return i.newThreadSelfSymlink(ctx, root), nil } return nil, syserror.ENOENT } @@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { +func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 07c27cdd9..01b7a6678 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -43,9 +43,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &selfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } @@ -84,9 +84,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &threadSelfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 95420368d..7c7afdcfa 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -40,93 +40,93 @@ const ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. 
-func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { - return fs.newStaticDir(root, map[string]kernfs.Inode{ - "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{ - "hostname": fs.newInode(root, 0444, &hostnameData{}), - "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { + return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), }), - "vm": fs.newStaticDir(root, map[string]kernfs.Inode{ - "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")), + "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), - "net": fs.newSysNetDir(root, k), + "net": fs.newSysNetDir(ctx, root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { +func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { var contents map[string]kernfs.Inode // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ - "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}), + "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). 
- "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")), - "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")), - "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")), - "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")), + "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), + "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")), - "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), - "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")), - "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")), - "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")), - "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")), - "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")), - "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")), + "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")), + "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")), + 
"tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")), }), - "core": fs.newStaticDir(root, map[string]kernfs.Inode{ - "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")), - "message_burst": fs.newInode(root, 0444, newStaticFile("10")), - "message_cost": fs.newInode(root, 0444, newStaticFile("5")), - "optmem_max": fs.newInode(root, 0444, newStaticFile("0")), - "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")), - "somaxconn": fs.newInode(root, 0444, newStaticFile("128")), - "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")), + "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")), + "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")), + "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), }), } } - return fs.newStaticDir(root, contents) + return fs.newStaticDir(ctx, root, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2582ababd..7ee6227a9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -77,6 +77,7 @@ var ( "gid_map": linux.DT_REG, "io": linux.DT_REG, "maps": linux.DT_REG, + "mem": linux.DT_REG, "mountinfo": linux.DT_REG, "mounts": linux.DT_REG, "net": linux.DT_DIR, diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index cf91ea36c..fda1fa942 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { +func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry { fs := mnt.Filesystem().Impl().(*filesystem) // File mode matches net/socket.c:sock_alloc. 
filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) + i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(&fs.Filesystem, i) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 906cd52cb..09043b572 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "dir_refs.go", package = "sys", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -28,6 +28,7 @@ go_library( "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 31a361029..b13f141a8 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -29,7 +29,7 @@ import ( func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode { k := &kcovInode{} - k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) return k } diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1ad679830..506a2a0f0 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -18,6 +18,7 @@ package sys import ( "bytes" "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -29,9 +30,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "sysfs" -const defaultSysDirMode = linux.FileMode(0755) +const ( + // Name is the default filesystem name. + Name = "sysfs" + defaultSysDirMode = linux.FileMode(0755) + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType implements vfs.FilesystemType. 
// @@ -62,31 +66,43 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + fs := &filesystem{ devMinor: devMinor, } + fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil), }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ + "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "cpu": cpuDir(ctx, fs, creds), }), }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), + "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), "kernel": kernelDir(ctx, fs, creds), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), + "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), }) var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) return fs.VFSFilesystem(), rootD.VFSDentry(), nil } @@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() children := map[string]kernfs.Inode{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), } for i := uint(0); i < maxCPUCores; i++ { - children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil) } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { @@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds 
*auth.Credentials) ker var children map[string]kernfs.Inode if coverage.KcovAvailable() { children = map[string]kernfs.Inode{ - "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{ + "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ "kcov": fs.newKcovFile(ctx, creds), }), } } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } // Release implements vfs.FilesystemImpl.Release. @@ -140,9 +156,9 @@ type dir struct { locks vfs.FileLocks } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { d := &dir{} - d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) d.EnableLeakCheck() d.IncLinks(d.OrderedChildren.Populate(contents)) @@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { +func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) return c } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 5cd428d64..fe520b6fd 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -31,7 +31,7 @@ go_template_instance( out = "inode_refs.go", package = "tmpfs", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -48,6 +48,7 @@ go_library( "inode_refs.go", "named_pipe.go", "regular_file.go", + "save_restore.go", "socket_file.go", "symlink.go", "tmpfs.go", @@ -60,6 +61,7 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index d772db9e9..57e7b57b0 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -32,7 +31,7 @@ type namedPipe struct { // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} + file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. 
return &file.inode diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index ce4e3eda7..98680fde9 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -42,7 +42,7 @@ type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile + memFile *pgalloc.MemoryFile `state:"nosave"` // memoryUsageKind is the memory accounting category under which pages backing // this regularFile's contents are accounted. @@ -92,7 +92,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, + memFile: fs.mfp.MemoryFile(), memoryUsageKind: usage.Tmpfs, seals: linux.F_SEAL_SEAL, } diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go new file mode 100644 index 000000000..b27f75cc2 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +// afterLoad is called by stateify. +func (rf *regularFile) afterLoad() { + rf.memFile = rf.inode.fs.mfp.MemoryFile() +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index e2a0aac69..4ce859d57 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -61,8 +61,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile + // mfp is used to allocate memory that stores regular file contents. mfp is + // immutable. + mfp pgalloc.MemoryFileProvider // clock is a realtime clock used to set timestamps in file operations. clock time.Clock @@ -106,8 +107,8 @@ type FilesystemOpts struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. 
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { + mfp := pgalloc.MemoryFileProviderFromContext(ctx) + if mfp == nil { panic("MemoryFileProviderFromContext returned nil") } @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } clock := time.RealtimeClockFromContext(ctx) fs := filesystem{ - memFile: memFileProvider.MemoryFile(), + mfp: mfp, clock: clock, devMinor: devMinor, } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 0ca750281..e265be0ee 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -6,6 +6,7 @@ go_library( name = "verity", srcs = [ "filesystem.go", + "save_restore.go", "verity.go", ], visibility = ["//pkg/sentry:internal"], @@ -15,6 +16,7 @@ go_library( "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", @@ -38,10 +40,12 @@ go_test( "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 03da505e1..4e8d63d51 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. 
If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // Verify returns with success. var buf bytes.Buffer if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize()), + ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, }); err != nil && err != io.EOF { - return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. 
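Note on the hash-algorithm plumbing in the Verify call sites above and below: each now passes HashAlgorithms and sizes the expected digest with merkletree.DigestSize(fs.alg.toLinuxHashAlg()). The standalone sketch below only mirrors the shape of that relationship between the configured algorithm and the digest length; it redeclares a local HashAlgorithm rather than importing gVisor's verity or merkletree packages, and takes the digest sizes from the standard library.

package main

import (
	"crypto/sha256"
	"crypto/sha512"
	"fmt"
)

// HashAlgorithm stands in for the type the verity changes define later in
// verity.go; it is redeclared here only so the sketch compiles on its own.
type HashAlgorithm int

const (
	SHA256 HashAlgorithm = iota
	SHA512
)

// digestSize returns the byte length of a digest produced by alg, the
// quantity the verification read size is derived from.
func digestSize(alg HashAlgorithm) int {
	switch alg {
	case SHA256:
		return sha256.Size // 32 bytes
	case SHA512:
		return sha512.Size // 64 bytes
	default:
		return 0
	}
}

func main() {
	fmt.Println("sha256 digest bytes:", digestSize(SHA256))
	fmt.Println("sha512 digest bytes:", digestSize(SHA512))
}

This covers only the size calculation; the actual Verify calls also thread the algorithm through to the Merkle tree reader, as the surrounding hunks show.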
@@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err @@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat }) if err == syserror.ENODATA { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat var buf bytes.Buffer params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - ReadOffset: 0, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), + ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, Expected: d.hash, @@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID + d.size = uint32(size) return nil } // Preconditions: fs.renameMu must be locked. d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { + // If verity is enabled on child, we should check again whether + // the file and the corresponding Merkle tree are as expected, + // in order to catch deletion/renaming after the last time it's + // accessed. + if child.verityEnabled() { + vfsObj := fs.vfsfs.VirtualFilesystem() + // Get the path to the child dentry. This is only used + // to provide path information in failure case. + path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + // The file was previously accessed. If the + // file does not exist now, it indicates an + // unexpected modification to the file system. 
+ return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + } + if err != nil { + return nil, err + } + defer childVD.DecRef(ctx) + + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + // The Merkle tree file was previous accessed. If it + // does not exist now, it indicates an unexpected + // modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + } + if err != nil { + return nil, err + } + + defer childMerkleVD.DecRef(ctx) + } + // If enabling verification on files/directories is not allowed // during runtime, all cached children are already verified. If // runtime enable is allowed and the parent directory is @@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childFilename := fspath.Parse(name) - childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: childFilename, - }, &vfs.GetDentryOptions{}) - + childVD, childErr := parent.getLowerAt(ctx, vfsObj, name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childErr != nil && childErr != syserror.ENOENT { @@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) } - childMerkleFilename := merklePrefix + name - childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - + childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { @@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. 
- return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: parent.lowerVD, Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), + Path: fspath.Parse(merklePrefix + name), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, Mode: 0644, @@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } childMerkleFD.DecRef(ctx) - childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { return nil, err } @@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { // Both the child and the corresponding Merkle tree are missing. @@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // TODO(b/167752508): Investigate possible ways to differentiate // cases that both files are deleted from cases that they never // exist in the file system. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. 
if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go new file mode 100644 index 000000000..46b064342 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 8dc9e26bc..d24c839bb 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -23,6 +23,7 @@ package verity import ( "fmt" + "math" "strconv" "sync/atomic" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -41,32 +43,62 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Name is the default filesystem name. -const Name = "verity" +const ( + // Name is the default filesystem name. + Name = "verity" -// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle -// tree file for "/foo" is "/.merkle.verity.foo". -const merklePrefix = ".merkle.verity." + // merklePrefix is the prefix of the Merkle tree files. For example, the Merkle + // tree file for "/foo" is "/.merkle.verity.foo". + merklePrefix = ".merkle.verity." -// merkleoffsetInParentXattr is the extended attribute name specifying the -// offset of child hash in its parent's Merkle tree. 
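// Illustrative, self-contained sketch of the naming convention documented in
// the merklePrefix comment above: the Merkle tree file for "<dir>/<name>" is
// stored alongside it as "<dir>/.merkle.verity.<name>". merklePathFor is a
// hypothetical helper for illustration, not part of the verity package.
package main

import (
	"fmt"
	"path"
)

const merklePrefix = ".merkle.verity."

// merklePathFor returns the path of the Merkle tree file for the file at p.
func merklePathFor(p string) string {
	dir, name := path.Split(p)
	return path.Join(dir, merklePrefix+name)
}

func main() {
	fmt.Println(merklePathFor("/foo"))     // /.merkle.verity.foo
	fmt.Println(merklePathFor("/a/b/bar")) // /a/b/.merkle.verity.bar
}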
-const merkleOffsetInParentXattr = "user.merkle.offset" + // merkleOffsetInParentXattr is the extended attribute name specifying the + // offset of the child hash in its parent's Merkle tree. + merkleOffsetInParentXattr = "user.merkle.offset" -// merkleSizeXattr is the extended attribute name specifying the size of data -// hashed by the corresponding Merkle tree. For a file, it's the size of the -// whole file. For a directory, it's the size of all its children's hashes. -const merkleSizeXattr = "user.merkle.size" + // merkleSizeXattr is the extended attribute name specifying the size of data + // hashed by the corresponding Merkle tree. For a regular file, this is the + // file size. For a directory, this is the size of all its children's hashes. + merkleSizeXattr = "user.merkle.size" -// sizeOfStringInt32 is the size for a 32 bit integer stored as string in -// extended attributes. The maximum value of a 32 bit integer is 10 digits. -const sizeOfStringInt32 = 10 + // sizeOfStringInt32 is the size for a 32 bit integer stored as string in + // extended attributes. The maximum value of a 32 bit integer has 10 digits. + sizeOfStringInt32 = 10 +) -// noCrashOnVerificationFailure indicates whether the sandbox should panic -// whenever verification fails. If true, an error is returned instead of -// panicking. This should only be set for tests. -// TOOD(b/165661693): Decide whether to panic or return error based on this -// flag. -var noCrashOnVerificationFailure bool +var ( + // noCrashOnVerificationFailure indicates whether the sandbox should panic + // whenever verification fails. If true, an error is returned instead of + // panicking. This should only be set for tests. + // + // TODO(b/165661693): Decide whether to panic or return error based on this + // flag. + noCrashOnVerificationFailure bool + + // verityMu synchronizes concurrent operations that enable verity and perform + // verification checks. + verityMu sync.RWMutex +) + +// HashAlgorithm is a type specifying the algorithm used to hash the file +// content. +type HashAlgorithm int + +// Currently supported hashing algorithms include SHA256 and SHA512. +const ( + SHA256 HashAlgorithm = iota + SHA512 +) + +func (alg HashAlgorithm) toLinuxHashAlg() int { + switch alg { + case SHA256: + return linux.FS_VERITY_HASH_ALG_SHA256 + case SHA512: + return linux.FS_VERITY_HASH_ALG_SHA512 + default: + return 0 + } +} // FilesystemType implements vfs.FilesystemType. // @@ -97,6 +129,10 @@ type filesystem struct { // stores the root hash of the whole file system in bytes. rootDentry *dentry + // alg is the algorithms used to hash the files in the verity file + // system. + alg HashAlgorithm + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -125,6 +161,10 @@ type InternalFilesystemOptions struct { // LowerName is the name of the filesystem wrapped by verity fs. LowerName string + // Alg is the algorithms used to hash the files in the verity file + // system. + Alg HashAlgorithm + // RootHash is the root hash of the overall verity file system. RootHash []byte @@ -153,10 +193,10 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. 
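// Illustrative sketch of what the HashAlgorithm values above select between:
// SHA-256 produces 32-byte digests and SHA-512 produces 64-byte digests. The
// newHash mapping below is a stand-in for illustration only; it is not the
// merkletree package API.
package main

import (
	"crypto/sha256"
	"crypto/sha512"
	"fmt"
	"hash"
)

type HashAlgorithm int

const (
	SHA256 HashAlgorithm = iota
	SHA512
)

// newHash returns a hash.Hash for alg, defaulting to SHA-256.
func newHash(alg HashAlgorithm) hash.Hash {
	switch alg {
	case SHA512:
		return sha512.New()
	default:
		return sha256.New()
	}
}

func main() {
	for _, alg := range []HashAlgorithm{SHA256, SHA512} {
		h := newHash(alg)
		h.Write([]byte("verity"))
		fmt.Printf("alg %d: %d-byte digest\n", alg, len(h.Sum(nil)))
	}
}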
-func alertIntegrityViolation(err error, msg string) error { +// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +func alertIntegrityViolation(msg string) error { if noCrashOnVerificationFailure { - return err + return syserror.EIO } panic(msg) } @@ -183,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs := &filesystem{ creds: creds.Fork(), + alg: iopts.Alg, lowerMount: mnt, allowRuntimeEnable: iopts.AllowRuntimeEnable, } @@ -236,7 +277,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -289,11 +330,12 @@ type dentry struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // mode, uid and gid are the file mode, owner, and group of the file in - // the underlying file system. + // mode, uid, gid and size are the file mode, owner, group, and size of + // the file in the underlying file system. mode uint32 uid uint32 gid uint32 + size uint32 // parent is the dentry corresponding to this dentry's parent directory. // name is this dentry's name in parent. If this dentry is a filesystem @@ -331,22 +373,25 @@ func (fs *filesystem) newDentry() *dentry { fs: fs, } d.vfsd.Init(d) + refsvfs2.Register(d) return d } // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -354,15 +399,27 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("verity.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("verity.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. func (d *dentry) checkDropLocked(ctx context.Context) { @@ -393,23 +450,36 @@ func (d *dentry) destroyLocked(ctx context.Context) { if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } - if d.lowerMerkleVD.Ok() { d.lowerMerkleVD.DecRef(ctx) } - if d.parent != nil { d.parent.dirMu.Lock() if !d.vfsd.IsDead() { delete(d.parent.children, d.name) } d.parent.dirMu.Unlock() - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("verity.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. 
+func (d *dentry) RefType() string { + return "verity.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -448,6 +518,16 @@ func (d *dentry) verityEnabled() bool { return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// getLowerAt returns the dentry in the underlying file system, which is +// represented by filename relative to d. +func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { + return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(filename), + }, &vfs.GetDentryOptions{}) +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -489,6 +569,10 @@ type fileDescription struct { // directory that contains the current file/directory. This is only used // if allowRuntimeEnable is set to true. parentMerkleWriter *vfs.FileDescription + + // off is the file offset. off is protected by mu. + mu sync.Mutex `state:"nosave"` + off int64 } // Release implements vfs.FileDescriptionImpl.Release. @@ -524,6 +608,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + n := int64(0) + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + n = fd.off + case linux.SEEK_END: + n = int64(fd.d.size) + default: + return 0, syserror.EINVAL + } + if offset > math.MaxInt64-n { + return 0, syserror.EINVAL + } + offset += n + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + // generateMerkle generates a Merkle tree file for fd. If fd points to a file // /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash // of the generated Merkle tree and the data size is returned. If fd points to @@ -546,6 +656,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -611,7 +723,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // or directory other than the root, the parent Merkle tree file should // have also been initialized. 
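// Minimal sketch of the Seek arithmetic implemented above: pick the base
// offset from whence, guard against int64 overflow, and reject negative
// results. curOff and fileSize are hypothetical inputs standing in for
// fd.off and fd.d.size; this is not the vfs.FileDescriptionImpl interface.
package main

import (
	"errors"
	"fmt"
	"math"
)

const (
	seekSet = 0 // linux.SEEK_SET
	seekCur = 1 // linux.SEEK_CUR
	seekEnd = 2 // linux.SEEK_END
)

var errInval = errors.New("EINVAL")

func seekOffset(curOff, fileSize, offset int64, whence int) (int64, error) {
	var base int64
	switch whence {
	case seekSet:
		// Use offset as given.
	case seekCur:
		base = curOff
	case seekEnd:
		base = fileSize
	default:
		return 0, errInval
	}
	if offset > math.MaxInt64-base {
		return 0, errInval // offset + base would overflow int64
	}
	offset += base
	if offset < 0 {
		return 0, errInval
	}
	return offset, nil
}

func main() {
	fmt.Println(seekOffset(10, 100, -5, seekCur)) // 5 <nil>
	fmt.Println(seekOffset(0, 100, 20, seekEnd))  // 120 <nil>
}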
if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkle(ctx) @@ -657,6 +769,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // measureVerity returns the hash of fd, saved in verityDigest. func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } var metadata linux.DigestMetadata // If allowRuntimeEnable is true, an empty fd.d.hash indicates that @@ -667,7 +782,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found") + return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -702,6 +817,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag } t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -722,6 +840,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Implement Read with PRead by setting offset. + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in @@ -742,7 +870,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -752,7 +880,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. 
size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := vfs.FileReadWriteSeeker{ @@ -766,25 +894,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, DataAndTreeInSameFile: false, }) if err != nil { - return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index e301d35f5..b2da9dd96 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -25,10 +25,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -41,11 +43,18 @@ const maxDataSize = 100000 // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. 
-func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("testutil.Boot: %v", err) + } + + ctx := k.SupervisorContext() + rand.Seed(time.Now().UnixNano()) vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -61,22 +70,33 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v InternalData: InternalFilesystemOptions{ RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", + Alg: hashAlg, AllowRuntimeEnable: true, NoCrashOnVerificationFailure: true, }, }, }) if err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) } root := mntns.Root() root.IncRef() + + // Use lowerRoot in the task as we modify the lower file system + // directly in many tests. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) + if err != nil { + t.Fatalf("testutil.CreateTask: %v", err) + } + t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, nil + return vfsObj, root, task, nil } // newFileFD creates a new file in the verity mount, and returns the FD. The FD @@ -142,207 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } +var hashAlgs = []HashAlgorithm{SHA256, SHA512} + // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Ensure that the corresponding Merkle tree file is created. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. 
+ lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } } +} - // Ensure the root merkle tree file is created. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + rootMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) +// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity +// file succeeds after enabling verity for it. +func TestPReadUnmodifiedFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } -// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file -// succeeds after enabling verity for it. +// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity +// file succeeds after enabling verity for it. func TestReadUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. 
- var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.PRead: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Ensure reopening the verity enabled file succeeds. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != nil { - t.Errorf("reopen enabled file failed: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } } } -// TestModifiedFileFails ensures that read from a modified verity file fails. 
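// Minimal sketch of the pattern the rewritten tests above follow: each case is
// repeated for every supported hash algorithm (hashAlgs). The subtest body and
// values below are placeholders, and the t.Run wrapper is a variant added here
// for clearer failure attribution; the tests above use a plain loop.
package main

import (
	"fmt"
	"testing"
)

type HashAlgorithm int

const (
	SHA256 HashAlgorithm = iota
	SHA512
)

var hashAlgs = []HashAlgorithm{SHA256, SHA512}

func TestPerAlgorithm(t *testing.T) {
	for _, alg := range hashAlgs {
		alg := alg // capture for the closure
		t.Run(fmt.Sprintf("alg=%d", alg), func(t *testing.T) {
			// A real test would build a verity mount configured with alg
			// here and then exercise open/read/stat against it.
			if alg != SHA256 && alg != SHA512 {
				t.Fatalf("unsupported algorithm: %d", alg)
			}
		})
	}
}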
-func TestModifiedFileFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) +// TestPReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestPReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded, expected failure") + } } +} - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified file") +// TestReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. 
+ lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") + } } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD - - lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerMerkleVD, - Start: lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the Merkle tree file. - stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - merkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - // Confirm that read from a file with modified Merkle tree fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) - t.Fatalf("fd.PRead succeeded with modified Merkle file") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. 
+ buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } } } @@ -350,142 +459,267 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Enable verity on the parent directory. - parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD - - parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: parentLowerMerkleVD, - Start: parentLowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the parent Merkle tree file. - // This parent directory contains only one child, so any random - // modification in the parent Merkle tree should cause verification - // failure when opening the child file. - stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - parentMerkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - parentLowerMerkleFD.DecRef(ctx) - - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err == nil { - t.Errorf("OpenAt file with modified parent Merkle succeeded") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. 
+ parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } } } // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { - t.Errorf("fd.Stat: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms stat succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { + t.Errorf("fd.Stat: %v", err) + } } } // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. 
func TestModifiedStatFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + lowerFD := fd.Impl().(*fileDescription).lowerFD + // Change the stat of the underlying file, and check that stat fails. + if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: uint32(linux.STATX_MODE), + Mode: 0777, + }, + }); err != nil { + t.Fatalf("lowerFD.SetStat: %v", err) + } - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { + t.Errorf("fd.Stat succeeded when it should fail") + } } +} - lowerFD := fd.Impl().(*fileDescription).lowerFD - // Change the stat of the underlying file, and check that stat fails. - if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: uint32(linux.STATX_MODE), - Mode: 0777, +// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed +// verity enabled file or the corresponding Merkle tree file fails with the +// verify error. +func TestOpenDeletedFileFails(t *testing.T) { + testCases := []struct { + // Tests removing files if remove is true. Otherwise tests + // renaming files. + remove bool + // The original file is removed/renamed if changeFile is true. + changeFile bool + // The Merkle tree file is removed/renamed if changeMerkleFile + // is true. + changeMerkleFile bool + }{ + { + remove: true, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: true, + changeFile: false, + changeMerkleFile: true, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: false, + changeFile: false, + changeMerkleFile: true, + }, - }); err != nil { - t.Fatalf("lowerFD.SetStat: %v", err) } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { - t.Errorf("fd.Stat succeeded when it should fail") + for _, tc := range testCases { + t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file.
+ var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.remove { + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } else { + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + } + }) } } diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD index 364a78306..db3b0d0a0 100644 --- a/pkg/sentry/hostfd/BUILD +++ b/pkg/sentry/hostfd/BUILD @@ -6,10 +6,12 @@ go_library( name = "hostfd", srcs = [ "hostfd.go", + "hostfd_linux.go", "hostfd_unsafe.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/log", "//pkg/safemem", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go new file mode 100644 index 000000000..1cabc848f --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_linux.go @@ -0,0 +1,18 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostfd + +// maxIov is the maximum permitted size of a struct iovec array.
+const maxIov = 1024 // UIO_MAXIOV diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go index cd4dc67fb..694371b1c 100644 --- a/pkg/sentry/hostfd/hostfd_unsafe.go +++ b/pkg/sentry/hostfd/hostfd_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" ) @@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6 } } else { iovs := safemem.IovecsFromBlockSeq(dsts) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { @@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint } } else { iovs := safemem.IovecsFromBlockSeq(srcs) + if len(iovs) > maxIov { + log.Debugf("hostfd.Pwritev2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, interfaceAddr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6.
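// Minimal sketch of the iovec cap added above: Linux rejects vectorized I/O
// with more than UIO_MAXIOV (1024) segments, so the segment list is truncated
// and the caller sees a short read/write and retries with the remainder.
// writevCapped is a hypothetical helper for illustration, not the hostfd API.
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const maxIov = 1024 // UIO_MAXIOV

// writevCapped issues a single writev on fd using at most maxIov segments and
// returns the number of bytes written; callers loop for any remainder.
func writevCapped(fd int, bufs [][]byte) (int, error) {
	if len(bufs) > maxIov {
		bufs = bufs[:maxIov]
	}
	return unix.Writev(fd, bufs)
}

func main() {
	n, err := writevCapped(1 /* stdout */, [][]byte{[]byte("hello, "), []byte("world\n")})
	fmt.Println(n, err)
}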
func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c0de72eef..90dd4a047 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -79,7 +79,7 @@ go_template_instance( out = "fd_table_refs.go", package = "kernel", prefix = "FDTable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FDTable", }, @@ -90,7 +90,7 @@ go_template_instance( out = "fs_context_refs.go", package = "kernel", prefix = "FSContext", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FSContext", }, @@ -101,7 +101,7 @@ go_template_instance( out = "ipc_namespace_refs.go", package = "kernel", prefix = "IPCNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "IPCNamespace", }, @@ -112,7 +112,7 @@ go_template_instance( out = "process_group_refs.go", package = "kernel", prefix = "ProcessGroup", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ProcessGroup", }, @@ -123,7 +123,7 @@ go_template_instance( out = "session_refs.go", package = "kernel", prefix = "Session", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Session", }, @@ -229,7 +229,7 @@ go_library( "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 1b9721534..0ddbe5ff6 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs_vfs2" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) @@ -27,7 +27,7 @@ import ( // +stateify savable type abstractEndpoint struct { ep transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter name string ns *AbstractSocketNamespace } @@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { // its backing socket. type boundEndpoint struct { transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. @@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() @@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran // Remove removes the specified socket at name from the abstract socket // namespace, if it has not yet been replaced. 
-func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { +func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ec7344cd..7aba31587 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() - f.init() // Initialize table. + f.initNoLeakCheck() // Initialize table. f.used = 0 for fd, d := range m { if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { @@ -240,6 +240,10 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() + vd := fileVFS2.VirtualDentry() + if vd.Dentry() == nil { + panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2)) + } name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index da79e6627..3476551f3 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -31,14 +31,21 @@ type descriptorTable struct { slice unsafe.Pointer `state:".(map[int32]*descriptor)"` } -// init initializes the table. +// initNoLeakCheck initializes the table without enabling leak checking. // -// TODO(gvisor.dev/1486): Enable leak check for FDTable. -func (f *FDTable) init() { +// This is used when loading an FDTable after S/R, during which the ref count +// object itself will enable leak checking if necessary. +func (f *FDTable) initNoLeakCheck() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) } +// init initializes the table with leak checking. +func (f *FDTable) init() { + f.initNoLeakCheck() + f.EnableLeakCheck() +} + // get gets a file entry. // // The boolean indicates whether this was in range. diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index d46d1e1c1..41fb2a784 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext { f.root.IncRef() } - return &FSContext{ + ctx := &FSContext{ cwd: f.cwd, root: f.root, cwdVFS2: f.cwdVFS2, rootVFS2: f.rootVFS2, umask: f.umask, } + ctx.EnableLeakCheck() + return ctx } // WorkingDirectory returns the current working directory. @@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() - f.cwd.IncRef() + if f.cwd != nil { + f.cwd.IncRef() + } return f.cwd } // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.cwdVFS2.IncRef() + if f.cwdVFS2.Ok() { + f.cwdVFS2.IncRef() + } return f.cwdVFS2 } @@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. 
// -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.rootVFS2.IncRef() + if f.rootVFS2.Ok() { + f.rootVFS2.IncRef() + } return f.rootVFS2 } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 3f34ee0db..b87e40dd1 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } -// DecRef implements refs_vfs2.RefCounter.DecRef. +// DecRef implements refsvfs2.RefCounter.DecRef. func (i *IPCNamespace) DecRef(ctx context.Context) { i.IPCNamespaceRefs.DecRef(func() { i.shms.Release(ctx) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0eb2bf7bd..9b2be44d4 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -430,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w wire.Writer) error { +func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { saveStart := time.Now() - ctx := k.SupervisorContext() // Do not allow other Kernel methods to affect it while it's being saved. k.extMu.Lock() @@ -446,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error { k.mf.StartEvictions() k.mf.WaitForEvictions() - // Flush write operations on open files so data reaches backing storage. - // This must come after MemoryFile eviction since eviction may cause file - // writes. - if err := k.tasks.flushWritesToFiles(ctx); err != nil { - return err - } + if VFS2Enabled { + // Discard unsavable mappings, such as those for host file descriptors. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) + // Prepare filesystems for saving. This must be done after + // invalidateUnsavableMappings(), since dropping memory mappings may + // affect filesystem state (e.g. page cache reference counts). + if err := k.vfs.PrepareSave(ctx); err != nil { + return err + } + } else { + // Flush cached file writes to backing storage. This must come after + // MemoryFile eviction since eviction may cause file writes. + if err := k.flushWritesToFiles(ctx); err != nil { + return err + } - // Clear the dirent cache before saving because Dirents must be Loaded in a - // particular order (parents before children), and Loading dirents from a cache - // breaks that order. - if err := k.flushMountSourceRefs(ctx); err != nil { - return err - } + // Remove all epoll waiter objects from underlying wait queues. + // NOTE: for programs to resume execution in future snapshot scenarios, + // we will need to re-establish these waiter objects after saving. + k.tasks.unregisterEpollWaiters(ctx) - // Ensure that all inode and mount release operations have completed. 
- fs.AsyncBarrier() + // Clear the dirent cache before saving because Dirents must be Loaded in a + // particular order (parents before children), and Loading dirents from a cache + // breaks that order. + if err := k.flushMountSourceRefs(ctx); err != nil { + return err + } - // Once all fs work has completed (flushed references have all been released), - // reset mount mappings. This allows individual mounts to save how inodes map - // to filesystem resources. Without this, fs.Inodes cannot be restored. - fs.SaveInodeMappings() + // Ensure that all inode and mount release operations have completed. + fs.AsyncBarrier() - // Discard unsavable mappings, such as those for host file descriptors. - // This must be done after waiting for "asynchronous fs work", which - // includes async I/O that may touch application memory. - if err := k.invalidateUnsavableMappings(ctx); err != nil { - return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + // Once all fs work has completed (flushed references have all been released), + // reset mount mappings. This allows individual mounts to save how inodes map + // to filesystem resources. Without this, fs.Inodes cannot be restored. + fs.SaveInodeMappings() + + // Discard unsavable mappings, such as those for host file descriptors. + // This must be done after waiting for "asynchronous fs work", which + // includes async I/O that may touch application memory. + // + // TODO(gvisor.dev/issue/1624): This rationale is believed to be + // obsolete since AIO callbacks are now waited-for by Kernel.Pause(), + // but this order is conservatively retained for VFS1. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } } // Save the CPUID FeatureSet before the rest of the kernel so we can @@ -486,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { + if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - stats, err := state.Save(k.SupervisorContext(), w, k) + stats, err := state.Save(ctx, w, k) if err != nil { return err } @@ -502,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { + if err := k.mf.SaveTo(ctx, w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -514,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. +// +// Preconditions: !VFS2Enabled. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { - if VFS2Enabled { - return nil // Not relevant. - } - // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -561,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi return err } -func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. 
- if VFS2Enabled { - return nil - } - - return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { +// Preconditions: !VFS2Enabled. +func (k *Kernel) flushWritesToFiles(ctx context.Context) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -589,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: The kernel must be paused. -func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { - invalidated := make(map[*mm.MemoryManager]struct{}) - k.tasks.mu.RLock() - defer k.tasks.mu.RUnlock() - for t := range k.tasks.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { - return err - } - invalidated[mm] = struct{}{} - } - } - // I really wish we just had a sync.Map of all MMs... - if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { - return err - } - } - } - return nil -} - +// Preconditions: !VFS2Enabled. func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return - } - ts.mu.RLock() defer ts.mu.RUnlock() @@ -644,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { } } +// Preconditions: The kernel must be paused. +func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { + invalidated := make(map[*mm.MemoryManager]struct{}) + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + for t := range k.tasks.Root.tids { + // We can skip locking Task.mu here since the kernel is paused. + if mm := t.tc.MemoryManager; mm != nil { + if _, ok := invalidated[mm]; !ok { + if err := mm.InvalidateUnsavable(ctx); err != nil { + return err + } + invalidated[mm] = struct{}{} + } + } + // I really wish we just had a sync.Map of all MMs... + if r, ok := t.runState.(*runSyscallAfterExecStop); ok { + if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { + return err + } + } + } + return nil +} + // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -656,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { + if _, err := state.Load(ctx, r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -671,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the kernel state. kernelStart := time.Now() - stats, err := state.Load(k.SupervisorContext(), r, k) + stats, err := state.Load(ctx, r, k) if err != nil { return err } @@ -684,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the memory file's state. 
memoryStart := time.Now() - if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { + if err := k.mf.LoadFrom(ctx, r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -696,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock net.Resume() } - // Ensure that all pending asynchronous work is complete: - // - namedpipe opening - // - inode file opening - if err := fs.AsyncErrorBarrier(); err != nil { - return err + if VFS2Enabled { + if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil { + return err + } + } else { + // Ensure that all pending asynchronous work is complete: + // - namedpipe opening + // - inode file opening + if err := fs.AsyncErrorBarrier(); err != nil { + return err + } } tcpip.AsyncLoading.Wait() diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ce0db5583..d6fb0fdb8 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) type sleeper struct { @@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag d := fs.NewDirent(ctx, inode, "pipe") file, err := n.GetFile(ctx, d, flags) if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) + t.Errorf("open with flags %+v failed: %v", flags, err) + return nil, err } if doneChan != nil { doneChan <- struct{}{} @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 67beb0ad6..b989e14c7 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -26,18 +26,27 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) const ( // MinimumPipeSize is a hard limit of the minimum size of a pipe. - MinimumPipeSize = 64 << 10 + // It corresponds to fs/pipe.c:pipe_min_size. + MinimumPipeSize = usermem.PageSize + + // MaximumPipeSize is a hard limit on the maximum size of a pipe. + // It corresponds to fs/pipe.c:pipe_max_size. + MaximumPipeSize = 1048576 // DefaultPipeSize is the system-wide default size of a pipe in bytes. - DefaultPipeSize = MinimumPipeSize + // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS. + DefaultPipeSize = 16 * usermem.PageSize - // MaximumPipeSize is a hard limit on the maximum size of a pipe. - MaximumPipeSize = 8 << 20 + // atomicIOBytes is the maximum number of bytes that the pipe will + // guarantee to read or write atomically. + // It corresponds to limits.h:PIPE_BUF. + atomicIOBytes = 4096 ) // Pipe is an encapsulation of a platform-independent pipe. @@ -53,12 +62,6 @@ type Pipe struct { // This value is immutable. isNamed bool - // atomicIOBytes is the maximum number of bytes that the pipe will - // guarantee atomic reads or writes atomically. - // - // This value is immutable.
- atomicIOBytes int64 - // The number of active readers for this pipe. // // Access atomically. @@ -94,47 +97,34 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // -// N.B. The size and atomicIOBytes will be bounded. -func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +// N.B. The size will be bounded. +func NewPipe(isNamed bool, sizeBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } var p Pipe - initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + initPipe(&p, isNamed, sizeBytes) return &p } -func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { +func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } pipe.isNamed = isNamed pipe.max = sizeBytes - pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. -func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) +func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) { + p := NewPipe(false /* isNamed */, sizeBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { wanted := ops.left() avail := p.max - p.view.Size() if wanted > avail { - if wanted <= p.atomicIOBytes { + if wanted <= atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index fe97e9800..3dd739080 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -26,7 +26,7 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) { const capacity = MinimumPipeSize ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) + r, w := NewConnectedPipe(ctx, capacity) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { const atomicIOBytes = 2 ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) + r, w := NewConnectedPipe(ctx, atomicIOBytes) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { } } if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) + t.Errorf("Readv: got unexpected error %v", err) + return } } }() diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 1a152142b..7b23cbe86 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ 
-33,6 +33,8 @@ import ( // VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should // not be copied. +// +// +stateify savable type VFSPipe struct { // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -52,9 +54,9 @@ type VFSPipe struct { } // NewVFSPipe returns an initialized VFSPipe. -func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { +func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { var vp VFSPipe - initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes) + initPipe(&vp.pipe, isNamed, sizeBytes) return &vp } @@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements // non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to // other FileDescriptions for splice(2) and tee(2). +// +// +stateify savable type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1145faf13..1abfe2201 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { + if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } _, err := word.CopyOut(t, data) @@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: word := t.Arch().Native(uintptr(data)) - _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) + _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c00fa1138..b99c0bffa 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -103,6 +103,7 @@ type waiter struct { waiterEntry // value represents how much resource the waiter needs to wake up. + // The value is either 0 or negative. value int16 ch chan struct{} } @@ -283,6 +284,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File return nil } +// GetStat extracts semid_ds information from the set. +func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + ds := &linux.SemidDS{ + SemPerm: linux.IPCPerm{ + Key: uint32(s.key), + UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)), + Mode: uint16(s.perms.LinuxMode()), + Seq: 0, // IPC sequence not supported. + }, + SemOTime: s.opTime.TimeT(), + SemCTime: s.changeTime.TimeT(), + SemNSems: uint64(s.Size()), + } + return ds, nil +} + // SetVal overrides a semaphore value, waking up waiters as needed. 
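// A minimal sketch (not part of this patch) of how a semctl(IPC_STAT)-style
// caller might consume the new Set.GetStat above. The surrounding registry
// lookup and user copy-out are elided, and log.Infof is used only for
// illustration; creds is assumed to be the caller's *auth.Credentials.
func statSet(set *Set, creds *auth.Credentials) error {
	ds, err := set.GetStat(creds) // Fails with EACCES without read permission.
	if err != nil {
		return err
	}
	log.Infof("semid_ds: nsems=%d otime=%d ctime=%d mode=%#o",
		ds.SemNSems, ds.SemOTime, ds.SemCTime, ds.SemPerm.Mode)
	return nil
}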
func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error { if val < 0 || val > valueMax { @@ -320,7 +348,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti } for _, val := range vals { - if val < 0 || val > valueMax { + if val > valueMax { return syserror.ERANGE } } @@ -396,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) { return sem.pid, nil } +func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // The calling process must have read permission on the semaphore set. + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return 0, syserror.EACCES + } + + sem := s.findSem(num) + if sem == nil { + return 0, syserror.ERANGE + } + var cnt uint16 + for w := sem.waiters.Front(); w != nil; w = w.Next() { + if pred(w) { + cnt++ + } + } + return cnt, nil +} + +// CountZeroWaiters returns the number of waiters waiting for the sem's value to become zero. +func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value == 0 + }) +} + +// CountNegativeWaiters returns the number of waiters waiting for the sem's value to increase. +func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value < 0 + }) +} + // ExecuteOps attempts to execute a list of operations to the set. It only // succeeds when all operations can be applied. No changes are made if it fails. // @@ -548,11 +612,18 @@ func (s *Set) destroy() { } } +func abs(val int16) int16 { + if val < 0 { + return -val + } + return val +} + // wakeWaiters goes over all waiters and checks which of them can be notified. func (s *sem) wakeWaiters() { // Note that this will release all waiters waiting for 0 too. for w := s.waiters.Front(); w != nil; { - if s.value < w.value { + if s.value < abs(w.value) { // Still blocked, skip it. w = w.Next() continue diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index df5c8421b..5bddb0a36 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session { // // If this group isn't visible in this namespace, zero will be returned. It is // the callers responsibility to check that before using this function. -func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sids[s] +func (ns *PIDNamespace) IDOfSession(s *Session) SessionID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sids[s] } // SessionWithID returns the Session with the given ID in the PID namespace ns, // or nil if that given ID is not defined in this namespace. // // A reference is not taken on the session. -func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sessions[id] +func (ns *PIDNamespace) SessionWithID(id SessionID) *Session { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sessions[id] } // ProcessGroup returns the ThreadGroup's ProcessGroup. @@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup { // IDOfProcessGroup returns the process group assigned to pg in PID namespace ns. // // The same constraints apply as IDOfSession.
-func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.pgids[pg] +func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.pgids[pg] } // ProcessGroupWithID returns the ProcessGroup with the given ID in the PID // namespace ns, or nil if that given ID is not defined in this namespace. // // A reference is not taken on the process group. -func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.processGroups[id] +func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.processGroups[id] } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index f8a382fd8..80a592c8f 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "shm_refs.go", package = "shm", prefix = "Shm", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Shm", }, @@ -27,7 +27,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 682080c14..527344162 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } if opts.ChildSetTID { ctid := nt.ThreadID() - ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) + ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index ce134bf54..94dabbcd8 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,7 +18,8 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp }, nil } -// copyContext implements marshal.CopyContext. It wraps a task to allow copying -// memory to and from the task memory with custom usermem.IOOpts. -type copyContext struct { - *Task +type taskCopyContext struct { + ctx context.Context + t *Task opts usermem.IOOpts } -// AsCopyContext wraps the task and returns it as CopyContext. -func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { - return ©Context{t, opts} +// CopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. +func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext { + return &taskCopyContext{ + ctx: ctx, + t: t, + opts: opts, + } +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. 
+func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte { + if ctxTask, ok := cc.ctx.(*Task); ok { + return ctxTask.CopyScratchBuffer(size) + } + return make([]byte, size) +} + +func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) { + cc.t.mu.Lock() + tmm := cc.t.MemoryManager() + cc.t.mu.Unlock() + if !tmm.IncUsers() { + return nil, syserror.EFAULT + } + return tmm, nil +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyIn(cc.ctx, addr, dst, cc.opts) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyOut(cc.ctx, addr, src, cc.opts) +} + +type ownTaskCopyContext struct { + t *Task + opts usermem.IOOpts +} + +// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. The returned CopyContext may only be used by t's task +// goroutine. +// +// Since t already implements marshal.CopyContext, this is only needed to +// override the usermem.IOOpts used for the copy. +func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext { + return &ownTaskCopyContext{ + t: t, + opts: opts, + } } -// CopyInString copies a string in from the task's memory. -func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { - return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte { + return cc.t.CopyScratchBuffer(size) } -// CopyInBytes copies task memory into dst from an IO context. -func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { - return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts) } -// CopyOutBytes copies src into task memoryfrom an IO context. -func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { - return t.MemoryManager().CopyOut(t, addr, src, t.opts) +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts) } diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 9bc452e67..9e5c2d26f 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error { } if old != v.seq { - return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. 
Application may hang or get incorrect time from the VDSO.", old, v.seq) + return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq) } v.seq = next diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b4a47ccca..6dbeccfe2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -78,7 +78,7 @@ go_template_instance( out = "aio_mappable_refs.go", package = "mm", prefix = "aioMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "aioMappable", }, @@ -89,7 +89,7 @@ go_template_instance( out = "special_mappable_refs.go", package = "mm", prefix = "SpecialMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SpecialMappable", }, @@ -127,6 +127,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safecopy", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 0a54dd30d..acad4c793 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) { c.runData.requestInterruptWindow = 0 } +// bluepillSigBus is responsible for injecting NMI to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_NMI, 0); errno != 0 { + throw("NMI injection failed") + } +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 58f3d6fdd..965ad66b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -27,15 +27,20 @@ var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = syscall.SIGILL - // vcpuSErr is the event of system error. - vcpuSErr = kvmVcpuEvents{ + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce = kvmVcpuEvents{ exception: exception{ sErrPending: 1, - sErrHasEsr: 0, - pad: [6]uint8{0, 0, 0, 0, 0, 0}, - sErrEsr: 1, }, - rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } + + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI = kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, + } ) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index b35c930e2..9433d4da5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 { + uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { throw("sErr injection failed") } } + +// bluepillSigBus is responsible for injecting sError to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { throw("sErr injection failed") } } diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index eb05950cd..75085ac6a 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( // escapes: no. - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_NMI, 0); errno != 0 { - throw("NMI injection failed") - } + bluepillSigBus(c) continue // Rerun vCPU. default: throw("run failed") diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index dd45ad10b..5979aef97 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { // NewAddressSpace returns a new pagetable root. func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. - pageTables := pagetables.New(newAllocator()) - k.machine.mapUpperHalf(pageTables) + pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 84df0f878..b060d9544 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -38,6 +38,8 @@ const ( _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 + _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a + _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00 ) // Arm64: Architectural Feature Access Control Register EL1. @@ -149,6 +151,9 @@ const ( _ESR_SEGV_PEMERR_L1 = 0xd _ESR_SEGV_PEMERR_L2 = 0xe _ESR_SEGV_PEMERR_L3 = 0xf + + // Custom ISS field definitions for system error. + _ESR_ELx_SERR_NMI = 0x1 ) // Arm64: MMIO base address used to dispatch hypercalls. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 61ed24d01..e2fffc99b 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,6 +41,9 @@ type machine struct { // slots are currently being updated, and the caller should retry. nextSlot uint32 + // upperSharedPageTables tracks the read-only shared upper of all the pagetables. + upperSharedPageTables *pagetables.PageTables + // kernel is the set of global structures. kernel ring0.Kernel @@ -198,9 +202,7 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }, m.maxVCPUs) + m.kernel.Init(m.maxVCPUs) // Pull the maximum slots. 
maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -212,6 +214,13 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of slots is %d.", m.maxSlots) m.usedSlots = make([]uintptr, m.maxSlots) + // Create the upper shared pagetables and kernel(sentry) pagetables. + m.upperSharedPageTables = pagetables.New(newAllocator()) + m.mapUpperHalf(m.upperSharedPageTables) + m.upperSharedPageTables.Allocator.(*allocator).base.Drain() + m.upperSharedPageTables.MarkReadOnlyShared() + m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -225,7 +234,6 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -625,3 +633,35 @@ func (c *vCPU) BounceToKernel() { func (c *vCPU) BounceToHost() { c.bounce(true) } + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c67127d95..8e03c310d 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error { } } -// setSystemTimeLegacy calibrates and sets an approximate system time. -func (c *vCPU) setSystemTimeLegacy() error { - const minIterations = 10 - minimum := uint64(0) - for iter := 0; ; iter++ { - // Try to set the TSC to an estimate of where it will be - // on the host during a "fast" system call iteration. - start := uint64(ktime.Rdtsc()) - if err := c.setTSC(start + (minimum / 2)); err != nil { - return err - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. - end := uint64(ktime.Rdtsc()) - if end < start { - continue // Totally bogus: unstable TSC? - } - current := end - start - if current < minimum || iter == 0 { - minimum = current // Set our new minimum. 
- } - // Is this past minIterations and within ~10% of minimum? - upperThreshold := (((minimum << 3) + minimum) >> 3) - if iter >= minIterations && current <= upperThreshold { - return nil - } - } -} - // nonCanonical generates a canonical address return. // //go:nosplit @@ -464,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } -var execRegions = func() (regions []region) { +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + // Map all the executable regions so that all the entry functions + // are mapped in the upper half. applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { return } + if vr.accessType.Execute { - regions = append(regions, vr.region) + r := vr.region + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) } }) - return -}() - -func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - for _, r := range execRegions { - physical, length, ok := translateToPhysical(r.virtual) - if !ok || length < r.length { - panic("impossilbe translation") - } - pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|r.virtual), - r.length, - pagetables.MapOpts{AccessType: usermem.Execute}, - physical) - } for start, end := range m.kernel.EntryRegions() { regionLen := end - start physical, length, ok := translateToPhysical(start) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a163f956d..fd92c3873 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error { } c.floatingPointState = arch.NewFloatingPointData() + + return c.setSystemTime() +} + +// setTSC sets the counter Virtual Offset. +func (c *vCPU) setTSC(value uint64) error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + data = uint64(value) + + if err := c.setOneRegister(&reg); err != nil { + return err + } + return nil } +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + return c.setSystemTimeLegacy() +} + //go:nosplit func (c *vCPU) loadSegments(tid uint64) { // TODO(gvisor.dev/issue/1238): TLS is not supported. @@ -197,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) } else if !ring0.IsCanonical(regs.Sp) { - return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info) } // Assign PCIDs.
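// Annotation, not part of the patch: the "within ~10% of minimum" check in
// setSystemTimeLegacy above is the integer expression
// ((minimum << 3) + minimum) >> 3, i.e. minimum + minimum/8, so the loop in
// fact accepts any round trip up to 12.5% above the fastest one observed.
// A standalone check of that arithmetic with a hypothetical cycle count:
package main

import "fmt"

func main() {
	minimum := uint64(800) // Hypothetical fastest observed setTSC round trip, in TSC cycles.
	upperThreshold := ((minimum << 3) + minimum) >> 3
	fmt.Println(upperThreshold) // Prints 900, i.e. minimum + minimum/8.
}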
@@ -233,10 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) + case ring0.El0ErrNMI: + return c.fault(int32(syscall.SIGBUS), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef, - ring0.El1Sync_undef: + case ring0.El0SyncUndef: + return c.fault(int32(syscall.SIGILL), info) + case ring0.El1SyncUndef: *info = arch.SignalInfo{ Signo: int32(syscall.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 87a573cc4..327d48465 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -58,46 +58,55 @@ type Vector uintptr // Exception vectors. const ( - El1SyncInvalid = iota - El1IrqInvalid - El1FiqInvalid - El1ErrorInvalid + El1InvSync = iota + El1InvIrq + El1InvFiq + El1InvError + El1Sync El1Irq El1Fiq - El1Error + El1Err + El0Sync El0Irq El0Fiq - El0Error - El0Sync_invalid - El0Irq_invalid - El0Fiq_invalid - El0Error_invalid - El1Sync_da - El1Sync_ia - El1Sync_sp_pc - El1Sync_undef - El1Sync_dbg - El1Sync_inv - El0Sync_svc - El0Sync_da - El0Sync_ia - El0Sync_fpsimd_acc - El0Sync_sve_acc - El0Sync_sys - El0Sync_sp_pc - El0Sync_undef - El0Sync_dbg - El0Sync_inv + El0Err + + El0InvSync + El0InvIrq + El0InvFiq + El0InvErr + + El1SyncDa + El1SyncIa + El1SyncSpPc + El1SyncUndef + El1SyncDbg + El1SyncInv + + El0SyncSVC + El0SyncDa + El0SyncIa + El0SyncFpsimdAcc + El0SyncSveAcc + El0SyncSys + El0SyncSpPc + El0SyncUndef + El0SyncDbg + El0SyncInv + + El0ErrNMI + El0ErrBounce + _NR_INTERRUPTS ) // System call vectors. const ( - Syscall Vector = El0Sync_svc - PageFault Vector = El0Sync_da - VirtualizationException Vector = El0Error + Syscall Vector = El0SyncSVC + PageFault Vector = El0SyncDa + VirtualizationException Vector = El0ErrBounce ) // VirtualAddressBits returns the number bits available for virtual addresses. diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index e6daf24df..f9765771e 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -23,6 +23,9 @@ import ( // // This contains global state, shared by multiple CPUs. type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + KernelArchState } diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 00899273e..7a2275558 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -66,17 +66,9 @@ var ( KernelDataSegment SegmentDescriptor ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts - - // cpuEntries is array of kernelEntry for all cpus + // cpuEntries is array of kernelEntry for all cpus. cpuEntries []kernelEntry // globalIDT is our set of interrupt gates. 
diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 508236e46..a014dcbc0 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -32,15 +32,8 @@ var ( KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts } // CPUArchState contains CPU-specific arch state. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2370a9276..f489ad352 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -288,6 +288,10 @@ #define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) #define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +/* ISS field definitions for system error */ +#define ESR_ELx_SERR_MASK (0x1) +#define ESR_ELx_SERR_NMI (0x1) + // LOAD_KERNEL_ADDRESS loads a kernel address. #define LOAD_KERNEL_ADDRESS(from, to) \ MOVD from, to; \ @@ -366,6 +370,19 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// EXCEPTION_WITH_ERROR is a common exception handler function. +#define EXCEPTION_WITH_ERROR(user, vector) \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + WORD $0xd538601a; \ //MRS FAR_EL1, R26 + MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ + MOVD $user, R3; \ + MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. + MOVD $vector, R3; \ + MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ + MRS ESR_EL1, R3; \ + MOVD R3, CPU_ERROR_CODE(RSV_REG); \ + B ·kernelExitToEl1(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -503,6 +520,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1 MSR R1, ELR_EL1 + // restore sentry's tls. + MOVD CPU_REGISTERS+PTRACE_TLS(RSV_REG), R1 + MSR R1, TPIDR_EL0 + MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP @@ -659,21 +680,7 @@ el0_svc: el0_da: el0_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $1, R3 - MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - MRS ESR_EL1, R3 - MOVD R3, CPU_ERROR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, PageFault) el0_fpsimd_acc: B ·Shutdown(SB) @@ -688,10 +695,7 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - MOVD $El0Sync_undef, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, El0SyncUndef) el0_dbg: B ·Shutdown(SB) @@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0 TEXT ·El0_error(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + AND $ESR_ELx_SERR_MASK, R25, R24 + CMP $ESR_ELx_SERR_NMI, R24 + BEQ el0_nmi + B el0_bounce +el0_nmi: + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + WORD $0xd538601a //MRS FAR_EL1, R26 + + MOVD R26, CPU_FAULT_ADDR(RSV_REG) + + MOVD $1, R3 + MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. 
+ + MOVD $El0ErrNMI, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) + +el0_bounce: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0 MOVD $VirtualizationException, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 264be23d3..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -16,11 +16,9 @@ package ring0 // Init initializes a new kernel. // -// N.B. that constraints on KernelOpts must be satisfied. -// //go:nosplit -func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { - k.init(opts, maxCPUs) +func (k *Kernel) Init(maxCPUs int) { + k.init(maxCPUs) } // Halt halts execution. diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 3a9dff4cc..b55dc29b3 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -24,10 +24,7 @@ import ( ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables - +func (k *Kernel) init(maxCPUs int) { entrySize := reflect.TypeOf(kernelEntry{}).Size() var ( entries []kernelEntry diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index b294ccc7c..6cbbf001f 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,9 +25,7 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables +func (k *Kernel) init(maxCPUs int) { } // init initializes architecture-specific state. 
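// A minimal sketch (not part of the patch) of the new initialization order
// now that ring0.KernelOpts is gone: the caller builds one read-only shared
// upper page table and every PageTables, including the kernel's, is created
// with NewWithUpper. Names mirror the KVM machine code earlier in this diff;
// newAllocator is the kvm package helper and error handling is elided.
func setupKernelPageTables(k *ring0.Kernel, maxCPUs int) *pagetables.PageTables {
	upper := pagetables.New(newAllocator())
	// ... map the sentry's executable and entry regions into upper here ...
	upper.MarkReadOnlyShared()
	k.Init(maxCPUs) // Init no longer takes KernelOpts.
	k.PageTables = pagetables.NewWithUpper(newAllocator(), upper, ring0.KernelStartAddress)
	return upper
}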
diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 45eba960d..53bc3353c 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -47,43 +47,36 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) - fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) - fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) - fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) - fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err) fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) - fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err) - fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) - fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) - fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) - fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa) + fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa) + fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc) + fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef) + fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg) + fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv) - fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) - fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) - fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) - fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) - fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) - fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC) + fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa) + fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) + fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) + fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) + fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) + fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) + fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg) + fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv) - fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) - fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) - fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) - fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) - fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) - fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) - fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) - fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) - fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) - fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI) fmt.Fprintf(w, "#define PageFault 0x%02x\n", 
PageFault) fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 7f18ac296..bc16a1622 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -30,6 +30,10 @@ type PageTables struct { Allocator Allocator // root is the pagetable root. + // + // For some archs such as amd64, the upper half of the PTEs is cloned + // from and owned by upperSharedPageTables, which may be shared among + // many PageTables if upperSharedPageTables is not nil. root *PTEs // rootPhysical is the cached physical address of the root. @@ -39,15 +43,52 @@ type PageTables struct { // archPageTables includes architecture-specific features. archPageTables + + // upperSharedPageTables represents a read-only shared upper + // of the PageTables. When it is not nil, the upper is not + // allowed to be modified. + upperSharedPageTables *PageTables + + // upperStart is the start address of the upper portion that + // is shared from upperSharedPageTables. + upperStart uintptr + + // readOnlyShared indicates the PageTables are read-only and + // own the ranges that are shared with other PageTables. + readOnlyShared bool } -// New returns new PageTables. -func New(a Allocator) *PageTables { +// NewWithUpper returns new PageTables. +// +// upperSharedPageTables is used for mapping the upper half of the address +// space, starting at upperStart. These page tables should not be touched (as +// invalidations may be incorrect) after they are passed as +// upperSharedPageTables. Only when all dependent PageTables are gone +// may they be used. The intended use case is for kernel page tables, +// which are static and fixed. +// +// Precondition: upperStart must be between canonical ranges. +// Precondition: upperStart must be pgdSize aligned. +// Precondition: upperSharedPageTables must be marked read-only shared. +func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables { p := new(PageTables) p.Init(a) + if upperSharedPageTables != nil { + if !upperSharedPageTables.readOnlyShared { + panic("Only read-only shared pagetables can be used as upper") + } + p.upperSharedPageTables = upperSharedPageTables + p.upperStart = upperStart + p.cloneUpperShared() + } return p } +// New returns new PageTables. +func New(a Allocator) *PageTables { + return NewWithUpper(a, nil, 0) +} + // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. @@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true } // //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // Ignore changes to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } if !opts.AccessType.Any() { return p.Unmap(addr, length) } @@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // True is returned iff there was a previous mapping in the range. // -// Precondition: addr & length must be page-aligned. +// Precondition: addr & length must be page-aligned, and their sum must not overflow.
// // +checkescape:hard,stack // //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } w := unmapWalker{ pageTables: p, visitor: unmapVisitor{ @@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) w.iterateRange(uintptr(addr), uintptr(addr)+1) return w.visitor.physical + offset, w.visitor.opts } + +// MarkReadOnlyShared marks the pagetables read-only and can be shared. +// +// It is usually used on the pagetables that are used as the upper +func (p *PageTables) MarkReadOnlyShared() { + p.readOnlyShared = true +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 520161755..a4e416af7 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -24,14 +24,6 @@ import ( // archPageTables is architecture-specific data. type archPageTables struct { - // root is the pagetable root for kernel space. - root *PTEs - - // rootPhysical is the cached physical address of the root. - // - // This is saved only to prevent constant translation. - rootPhysical uintptr - asid uint16 } @@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { // //go:nosplit func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { - return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset + return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset } // Bits in page table entries. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 0c153cf8c..e7ab887e5 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) { p.rootPhysical = p.Allocator.PhysicalFor(p.root) } +func pgdIndex(upperStart uintptr) uintptr { + if upperStart&(pgdSize-1) != 0 { + panic("upperStart should be pgd size aligned") + } + if upperStart >= upperBottom { + return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize + } + if upperStart < lowerTop { + return upperStart / pgdSize + } + panic("upperStart should be in canonical range") +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + start := pgdIndex(p.upperStart) + copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage]) +} + // PTEs is a collection of entries. 
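For context on how the shared-upper API above fits together: a platform builds the kernel page tables once, freezes them with MarkReadOnlyShared, and then creates per-address-space page tables with NewWithUpper so the kernel half is cloned rather than rebuilt. A minimal sketch of that call pattern, assuming a hypothetical allocator alloc and a pgd-aligned kernelStart; the addresses and the MapOpts literal are illustrative, not taken from this patch.

package pagetables_test

import (
	"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
	"gvisor.dev/gvisor/pkg/usermem"
)

// buildPageTables sketches the intended use of NewWithUpper.
func buildPageTables(alloc pagetables.Allocator, kernelStart uintptr) *pagetables.PageTables {
	// Kernel page tables: populated once, then frozen. Any later Map/Unmap on
	// kpt panics ("Should not modify read-only shared pagetables.").
	kpt := pagetables.New(alloc)
	// ... map the kernel (upper) range into kpt here ...
	kpt.MarkReadOnlyShared()

	// Per-task page tables: the upper portion starting at kernelStart is
	// cloned from kpt, and Map/Unmap requests at or above kernelStart are
	// silently ignored.
	upt := pagetables.NewWithUpper(alloc, kpt, kernelStart)

	// User mappings below kernelStart behave exactly as before.
	upt.Map(usermem.Addr(0x400000), usermem.PageSize,
		pagetables.MapOpts{AccessType: usermem.ReadWrite}, 0x1234000)
	return upt
}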
type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 1a49f12a2..5392bf27a 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -36,7 +36,7 @@ const ( pudSize = 1 << pudShift pgdSize = 1 << pgdShift - ttbrASIDOffset = 55 + ttbrASIDOffset = 48 ttbrASIDMask = 0xff entriesPerPage = 512 @@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) { p.Allocator = allocator p.root = p.Allocator.NewPTEs() p.rootPhysical = p.Allocator.PhysicalFor(p.root) - p.archPageTables.root = p.Allocator.NewPTEs() - p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + if p.upperStart != upperBottom { + panic("upperStart should be the same as upperBottom") + } + + // nothing to do for arm. } // PTEs is a collection of entries. diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go index c261d393a..157c9a7cc 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr { func (w *Walker) iterateRangeCanonical(start, end uintptr) { pgdEntryIndex := w.pageTables.root if start >= upperBottom { - pgdEntryIndex = w.pageTables.archPageTables.root + pgdEntryIndex = w.pageTables.upperSharedPageTables.root } for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index d9621968c..37d02948f 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -24,6 +24,8 @@ import ( ) // SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +// +// +stateify savable type SCMRightsVFS2 interface { transport.RightsControlMessage @@ -34,9 +36,11 @@ type SCMRightsVFS2 interface { Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) } -// RightsFiles represents a SCM_RIGHTS socket control message. A reference is -// maintained for each vfs.FileDescription and is release either when an FD is created or -// when the Release method is called. +// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference +// is maintained for each vfs.FileDescription and is release either when an FD +// is created or when the Release method is called. 
+// +// +stateify savable type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 163af329b..9a2cac40b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// +stateify savable type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &socketVFS2{ diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index faa61160e..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } if rawLine == "" { - return fmt.Errorf("Failed to get raw line") + return fmt.Errorf("failed to get raw line") } parts := strings.SplitN(rawLine, ":", 2) if len(parts) != 2 { - return fmt.Errorf("Failed to get prefix from: %q", rawLine) + return fmt.Errorf("failed to get prefix from: %q", rawLine) } sliceStat = toSlice(stat) fields := strings.Fields(strings.TrimSpace(parts[1])) if len(fields) != len(sliceStat) { - return fmt.Errorf("Failed to parse fields: %q", rawLine) + return fmt.Errorf("failed to parse fields: %q", rawLine) } if _, ok := stat.(*inet.StatSNMPTCP); ok { snmpTCP = true @@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) } if err != nil { - return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err) } } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. 
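One small convention behind the hostinet/stack.go message edits above: Go error strings start in lower case and omit trailing punctuation (checkers such as staticcheck flag the capitalized form) because they are usually wrapped by callers. A tiny illustration with hypothetical names:

package hostinetexample

import "fmt"

func parseLine(rawLine string) error {
	if rawLine == "" {
		return fmt.Errorf("failed to get raw line")
	}
	return nil
}

func readStats(rawLine string) error {
	if err := parseLine(rawLine); err != nil {
		// Composes as "reading stats: failed to get raw line"; a capitalized
		// inner message would read awkwardly once wrapped like this.
		return fmt.Errorf("reading stats: %v", err)
	}
	return nil
}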
-func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 549787955..e0976fed0 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf // marshalTarget and unmarshalTarget can be used. type targetMaker interface { // id uniquely identifies the target. - id() stack.TargetID + id() targetID - // marshal converts from a stack.Target to an ABI struct. - marshal(target stack.Target) []byte + // marshal converts from a target to an ABI struct. + marshal(target target) []byte - // unmarshal converts from the ABI matcher struct to a stack.Target. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) + // unmarshal converts from the ABI matcher struct to a target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) } -// targetMakers maps the TargetID of supported targets to the targetMaker that +// A targetID uniquely identifies a target. +type targetID struct { + // name is the target name as stored in the xt_entry_target struct. + name string + + // networkProtocol is the protocol to which the target applies. + networkProtocol tcpip.NetworkProtocolNumber + + // revision is the version of the target. + revision uint8 +} + +// target extends a stack.Target, allowing it to be used with the extension +// system. The sentry only uses targets, never stack.Targets directly. +type target interface { + stack.Target + id() targetID +} + +// targetMakers maps the targetID of supported targets to the targetMaker that // marshals and unmarshals it. It is immutable after package initialization. -var targetMakers = map[stack.TargetID]targetMaker{} +var targetMakers = map[targetID]targetMaker{} func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { - tid := stack.TargetID{ - Name: name, - NetworkProtocol: netProto, - Revision: rev, + tid := targetID{ + name: name, + networkProtocol: netProto, + revision: rev, } if _, ok := targetMakers[tid]; !ok { return 0, false @@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8 // Return the highest supported revision unless rev is higher. for _, other := range targetMakers { otherID := other.id() - if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { - rev = uint8(otherID.Revision) + if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev { + rev = uint8(otherID.revision) } } return rev, true @@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) { targetMakers[tm.id()] = tm } -func marshalTarget(target stack.Target) []byte { - targetMaker, ok := targetMakers[target.ID()] +func marshalTarget(tgt stack.Target) []byte { + // The sentry only uses targets, never stack.Targets directly. 
+ target := tgt.(target) + targetMaker, ok := targetMakers[target.id()] if !ok { - panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id())) } return targetMaker.marshal(target) } -func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { - tid := stack.TargetID{ - Name: target.Name.String(), - NetworkProtocol: filter.NetworkProtocol(), - Revision: target.Revision, +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) { + tid := targetID{ + name: target.Name.String(), + networkProtocol: filter.NetworkProtocol(), + revision: target.Revision, } targetMaker, ok := targetMakers[tid] if !ok { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index b560fae0d..70c561cce 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), false) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct. - entries, info := getEntries4(table, tablename) + entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 4253f7bf4..5dbb604f0 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), true) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct, which is the same in IPv4 and IPv6. - entries, info := getEntries6(table, tablename) + entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 904a12e38..b283d7229 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) { } } +// Table names. +const ( + natTable = "nat" + mangleTable = "mangle" + filterTable = "filter" +) + +// nameToID is immutable. +var nameToID = map[string]stack.TableID{ + natTable: stack.NATID, + mangleTable: stack.MangleID, + filterTable: stack.FilterID, +} + +// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for +// compatibility with netfilter extensions. 
+func DefaultLinuxTables() *stack.IPTables { + tables := stack.DefaultTables() + tables.VisitTargets(func(oldTarget stack.Target) stack.Target { + switch val := oldTarget.(type) { + case *stack.AcceptTarget: + return &acceptTarget{AcceptTarget: *val} + case *stack.DropTarget: + return &dropTarget{DropTarget: *val} + case *stack.ErrorTarget: + return &errorTarget{ErrorTarget: *val} + case *stack.UserChainTarget: + return &userChainTarget{UserChainTarget: *val} + case *stack.ReturnTarget: + return &returnTarget{ReturnTarget: *val} + case *stack.RedirectTarget: + return &redirectTarget{RedirectTarget: *val} + default: + panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val)) + } + }) + return tables +} + // GetInfo returns information about iptables. func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. @@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.FilterTable: + case filterTable: table = stack.EmptyFilterTable() - case stack.NATTable: + case natTable: table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) @@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } if offset == replace.Underflow[hook] { if !validUnderflow(table.Rules[ruleIdx], ipv6) { - nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx) return syserr.ErrInvalidArgument } table.Underflows[hk] = ruleIdx @@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) - + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case *stack.AcceptTarget, *stack.DropTarget: + case *acceptTarget, *dropTarget: return true default: return false @@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(*stack.AcceptTarget) + _, ok := rule.Target.(*acceptTarget) return ok } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 0e14447fe..f2653d523 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -26,6 +26,15 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port and/or IP for packets. 
+const RedirectTargetName = "REDIRECT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -52,25 +61,96 @@ func init() { }) } +// The stack package provides some basic, useful targets for us. The following +// types wrap them for compatibility with the extension system. + +type acceptTarget struct { + stack.AcceptTarget +} + +func (at *acceptTarget) id() targetID { + return targetID{ + networkProtocol: at.NetworkProtocol, + } +} + +type dropTarget struct { + stack.DropTarget +} + +func (dt *dropTarget) id() targetID { + return targetID{ + networkProtocol: dt.NetworkProtocol, + } +} + +type errorTarget struct { + stack.ErrorTarget +} + +func (et *errorTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: et.NetworkProtocol, + } +} + +type userChainTarget struct { + stack.UserChainTarget +} + +func (uc *userChainTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: uc.NetworkProtocol, + } +} + +type returnTarget struct { + stack.ReturnTarget +} + +func (rt *returnTarget) id() targetID { + return targetID{ + networkProtocol: rt.NetworkProtocol, + } +} + +type redirectTarget struct { + stack.RedirectTarget + + // addr must be (un)marshalled when reading and writing the target to + // userspace, but does not affect behavior. + addr tcpip.Address +} + +func (rt *redirectTarget) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rt.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (sm *standardTargetMaker) id() stack.TargetID { +func (sm *standardTargetMaker) id() targetID { // Standard targets have the empty string as a name and no revisions. - return stack.TargetID{ - NetworkProtocol: sm.NetworkProtocol, + return targetID{ + networkProtocol: sm.NetworkProtocol, } } -func (*standardTargetMaker) marshal(target stack.Target) []byte { + +func (*standardTargetMaker) marshal(target target) []byte { // Translate verdicts the same way as the iptables tool. var verdict int32 switch tg := target.(type) { - case *stack.AcceptTarget: + case *acceptTarget: verdict = -linux.NF_ACCEPT - 1 - case *stack.DropTarget: + case *dropTarget: verdict = -linux.NF_DROP - 1 - case *stack.ReturnTarget: + case *returnTarget: verdict = linux.NF_RETURN case *JumpTarget: verdict = int32(tg.Offset) @@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTStandardTarget { nflog("buf has wrong size for standard target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -114,20 +194,20 @@ type errorTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (em *errorTargetMaker) id() stack.TargetID { +func (em *errorTargetMaker) id() targetID { // Error targets have no revision. 
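The target/targetID/targetMaker plumbing introduced above keeps extension metadata (names, revisions) inside package netfilter while wrapping plain stack.Target implementations. A sketch of what a new extension would look like inside that package; the LOG name and the logTarget/logTargetMaker types are hypothetical, only the interfaces and registerTargetMaker come from this change:

package netfilter

import (
	"gvisor.dev/gvisor/pkg/syserr"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// logTarget wraps an existing stack.Target so it satisfies the target
// interface (stack.Target plus id() targetID).
type logTarget struct {
	stack.AcceptTarget
}

// id implements target.id.
func (lt *logTarget) id() targetID {
	return targetID{name: "LOG", networkProtocol: lt.NetworkProtocol}
}

// logTargetMaker handles (un)marshalling for the hypothetical LOG target.
type logTargetMaker struct {
	NetworkProtocol tcpip.NetworkProtocolNumber
}

// id implements targetMaker.id.
func (lm *logTargetMaker) id() targetID {
	return targetID{name: "LOG", networkProtocol: lm.NetworkProtocol}
}

// marshal implements targetMaker.marshal; a real maker would emit the
// xt_entry_target ABI struct here.
func (*logTargetMaker) marshal(target) []byte { return nil }

// unmarshal implements targetMaker.unmarshal.
func (*logTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
	return &logTarget{stack.AcceptTarget{NetworkProtocol: filter.NetworkProtocol()}}, nil
}

func init() {
	registerTargetMaker(&logTargetMaker{NetworkProtocol: header.IPv4ProtocolNumber})
}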
- return stack.TargetID{ - Name: stack.ErrorTargetName, - NetworkProtocol: em.NetworkProtocol, + return targetID{ + name: ErrorTargetName, + networkProtocol: em.NetworkProtocol, } } -func (*errorTargetMaker) marshal(target stack.Target) []byte { +func (*errorTargetMaker) marshal(target target) []byte { var errorName string switch tg := target.(type) { - case *stack.ErrorTarget: - errorName = stack.ErrorTargetName - case *stack.UserChainTarget: + case *errorTarget: + errorName = ErrorTargetName + case *userChainTarget: errorName = tg.Name default: panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) @@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte { }, } copy(xt.Name[:], errorName) - copy(xt.Target.Name[:], stack.ErrorTargetName) + copy(xt.Target.Name[:], ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTErrorTarget { nflog("buf has insufficient size for error target %d", len(buf)) return nil, syserr.ErrInvalidArgument } - var errorTarget linux.XTErrorTarget + var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &errTgt) // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named stack.ErrorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. + // * An actual error case. These rules have an error named + // ErrorTargetName. The last entry of the table is usually an error + // case to catch any packets that somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case stack.ErrorTargetName: - return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + switch name := errTgt.Name.String(); name { + case ErrorTargetName: + return &errorTarget{stack.ErrorTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }}, nil default: // User defined chain. 
- return &stack.UserChainTarget{ + return &userChainTarget{stack.UserChainTarget{ Name: name, NetworkProtocol: filter.NetworkProtocol(), - }, nil + }}, nil } } @@ -178,22 +259,22 @@ type redirectTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *redirectTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *redirectTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*redirectTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*redirectTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) // This is a redirect target named redirect xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(xt.Target.Name[:], stack.RedirectTargetName) + copy(xt.Target.Name[:], RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 @@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) < linux.SizeOfXTRedirectTarget { nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - var redirectTarget linux.XTRedirectTarget + var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &rt) // Copy linux.XTRedirectTarget to stack.RedirectTarget. - target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + target := redirectTarget{RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} // RangeSize should be 1. 
- nfRange := redirectTarget.NfRange + nfRange := rt.NfRange if nfRange.RangeSize != 1 { nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) return nil, syserr.ErrInvalidArgument @@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) target.Port = ntohs(nfRange.RangeIPV4.MinPort) return &target, nil @@ -264,15 +347,15 @@ type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *nfNATTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *nfNATTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*nfNATTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*nfNATTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ TargetSize: nfNATMarhsalledSize, @@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, } - copy(nt.Target.Name[:], stack.RedirectTargetName) - copy(nt.Range.MinAddr[:], rt.Addr) - copy(nt.Range.MaxAddr[:], rt.Addr) + copy(nt.Target.Name[:], RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.addr) + copy(nt.Range.MaxAddr[:], rt.addr) nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto @@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, nt) } -func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if size := nfNATMarhsalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument @@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta return nil, syserr.ErrInvalidArgument } - target := stack.RedirectTarget{ - NetworkProtocol: filter.NetworkProtocol(), - Addr: tcpip.Address(natRange.MinAddr[:]), - Port: ntohs(natRange.MinProto), + target := redirectTarget{ + RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Port: ntohs(natRange.MinProto), + }, + addr: tcpip.Address(natRange.MinAddr[:]), } return &target, nil @@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. 
switch val { case -linux.NF_ACCEPT - 1: - return &stack.AcceptTarget{NetworkProtocol: netProto}, nil + return &acceptTarget{stack.AcceptTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_DROP - 1: - return &stack.DropTarget{NetworkProtocol: netProto}, nil + return &dropTarget{stack.DropTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_QUEUE - 1: nflog("unsupported iptables verdict QUEUE") return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return &stack.ReturnTarget{NetworkProtocol: netProto}, nil + return &returnTarget{stack.ReturnTarget{ + NetworkProtocol: netProto, + }}, nil default: nflog("unknown iptables verdict %d", val) return nil, syserr.ErrInvalidArgument @@ -382,9 +473,9 @@ type JumpTarget struct { } // ID implements Target.ID. -func (jt *JumpTarget) ID() stack.TargetID { - return stack.TargetID{ - NetworkProtocol: jt.NetworkProtocol, +func (jt *JumpTarget) id() targetID { + return targetID{ + networkProtocol: jt.NetworkProtocol, } } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 844acfede..352c51390 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.TCPProtocolNumber { - return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber) } return &TCPMatcher{ diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 63201201c..c88d8268d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber) } return &UDPMatcher{ diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index e8930f031..f061c5d62 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c84d8bd7c..f4d034c13 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -36,9 +36,9 @@ type commandKind int const ( kindNew commandKind = 0x0 - kindDel = 0x1 - kindGet = 0x2 - kindSet = 0x3 + kindDel commandKind = 0x1 + kindGet commandKind = 0x2 + kindSet commandKind = 0x3 ) func typeKind(typ uint16) commandKind { @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. 
+ // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } return nil } +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrBadLocalAddress + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { hdr := msg.Header() @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index c83b23242..461d524e5 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -37,6 +37,8 @@ import ( // to/from the kernel. // // SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. 
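To make the new RTM_DELADDR path concrete: once delAddr has parsed the IFA_LOCAL attribute, it hands the raw attribute bytes to inet.Stack.RemoveInterfaceAddr, which is what an in-sandbox "ip addr del" ends up exercising. A rough sketch of that hand-off; the helper and its parameters are illustrative, and delAddr additionally forwards the Flags field, omitted here for brevity.

package netlinkexample

import (
	"gvisor.dev/gvisor/pkg/sentry/inet"
	"gvisor.dev/gvisor/pkg/syserr"
)

// removeLocalAddr mirrors the IFA_LOCAL branch of delAddr above.
func removeLocalAddr(stk inet.Stack, ifIndex int32, family, prefixLen uint8, value []byte) *syserr.Error {
	err := stk.RemoveInterfaceAddr(ifIndex, inet.InterfaceAddr{
		Family:    family,    // from the InterfaceAddrMessage header
		PrefixLen: prefixLen, // likewise
		Addr:      value,     // raw bytes of the IFA_LOCAL attribute
	})
	if err != nil {
		return syserr.ErrBadLocalAddress
	}
	return nil
}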
+// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 211f07947..86c634715 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1244,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ACCEPTCONN: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 4c6791fff..b0d9e4d9e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -35,6 +35,8 @@ import ( // SocketVFS2 encapsulates all the state needed to represent a network stack // endpoint in the kernel context. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu } mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &SocketVFS2{ diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..fa9ac9059 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. 
+func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. 
- NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cc7408698..cce0acc33 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "socket_refs.go", package = "unix", prefix = "socketOperations", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketOperations", }, @@ -19,7 +19,7 @@ go_template_instance( out = "socket_vfs2_refs.go", package = "unix", prefix = "socketVFS2", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketVFS2", }, @@ -43,6 +43,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 26c3a51b9..3ebbd28b0 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "queue_refs.go", package = "transport", prefix = "queue", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "queue", }, @@ -44,6 +44,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d6fc03520..b648273a4 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -32,6 +32,8 @@ import ( const initialLimit = 16 * 1024 // A RightsControlMessage is a control message containing FDs. +// +// +stateify savable type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. Clone() RightsControlMessage @@ -336,7 +338,7 @@ type Receiver interface { RecvMaxQueueSize() int64 // Release releases any resources owned by the Receiver. It should be - // called before droping all references to a Receiver. + // called before dropping all references to a Receiver. Release(ctx context.Context) } @@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds c := q.control.Clone() // Don't consume data since we are peeking. - copied, data, _ = vecCopy(data, q.buffer) + copied, _, _ = vecCopy(data, q.buffer) return copied, copied, c, false, q.addr, notify, nil } @@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds return copied, copied, c, cmTruncated, q.addr, notify, nil } +// Release implements Receiver.Release. +func (q *streamQueueReceiver) Release(ctx context.Context) { + q.queueReceiver.Release(ctx) + q.control.Release(ctx) +} + // A ConnectedEndpoint is an Endpoint that can be used to send Messages. type ConnectedEndpoint interface { // Passcred implements Endpoint.Passcred. @@ -619,7 +627,7 @@ type ConnectedEndpoint interface { SendMaxQueueSize() int64 // Release releases any resources owned by the ConnectedEndpoint. It should - // be called before droping all references to a ConnectedEndpoint. + // be called before dropping all references to a ConnectedEndpoint. 
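The route bookkeeping added to AddInterfaceAddr and RemoveInterfaceAddr above keeps the on-link subnet route in sync with the interface address, using plain value equality on tcpip.Route. The same idea in isolation, as a sketch; the helper name is made up, while GetRouteTable, AddRoute, RemoveRoutes and Route.Equal are used exactly as in the change.

package netstackexample

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// syncLocalRoute installs or drops the local subnet route that corresponds to
// protocolAddress on nicID.
func syncLocalRoute(s *stack.Stack, nicID tcpip.NICID, protocolAddress tcpip.ProtocolAddress, add bool) {
	localRoute := tcpip.Route{
		Destination: protocolAddress.AddressWithPrefix.Subnet(),
		Gateway:     "", // on-link route, no gateway
		NIC:         nicID,
	}
	if add {
		for _, rt := range s.GetRouteTable() {
			if rt.Equal(localRoute) {
				return // already present, nothing to do
			}
		}
		s.AddRoute(localRoute)
		return
	}
	s.RemoveRoutes(func(rt tcpip.Route) bool { return rt.Equal(localRoute) })
}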
Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to @@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.PasscredOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index a4a76d0a3..adad485a9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty }, } s.EnableLeakCheck() - return fs.NewFile(ctx, d, flags, &s) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 678355fb9..7a78444dc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) @@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.EnableLeakCheck() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 0ea4aab8b..563d60578 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -12,10 +12,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/time", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 245d2c5cf..167754537 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -19,10 +19,12 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -57,7 +59,7 @@ type SaveOpts struct { } // Save saves the system state. -func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { +func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() @@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { err = ErrStateFile{err} } else { // Save the kernel. - err = k.SaveTo(wc) + err = k.SaveTo(ctx, wc) // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync @@ -108,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. 
-func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n, clocks) + return k.LoadFrom(ctx, r, n, clocks, vfsOpts) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 9c9def7cd..bb1f715e2 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 849a47476..f7135ea46 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return 0, syserror.EINVAL } - r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) + r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize) r.SetFlags(linuxToFlags(flags).Settable()) defer r.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 47dadb800..e383a0a87 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -129,13 +129,27 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getPID(t, id, num) return uintptr(v), nil, err + case linux.IPC_STAT: + arg := args[3].Pointer() + ds, err := ipcStat(t, id) + if err == nil { + _, err = ds.CopyOut(t, arg) + } + + return 0, nil, err + + case 
linux.GETZCNT: + v, err := getZCnt(t, id, num) + return uintptr(v), nil, err + + case linux.GETNCNT: + v, err := getNCnt(t, id, num) + return uintptr(v), nil, err + case linux.IPC_INFO, linux.SEM_INFO, - linux.IPC_STAT, linux.SEM_STAT, - linux.SEM_STAT_ANY, - linux.GETNCNT, - linux.GETZCNT: + linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -171,6 +185,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP return set.Change(t, creds, owner, perms) } +func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.GetStat(creds) +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) @@ -240,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) { } return int32(tg.ID()), nil } + +func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountZeroWaiters(num, creds) +} + +func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountNegativeWaiters(num, creds) +} diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 46616c961..1c4cdb0dd 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inCh chan struct{} outCh chan struct{} ) + for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) opts.Length -= n @@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inW, _ := waiter.NewChannelEntry(inCh) inFile.EventRegister(&inW, EventMaskRead) defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Need to refresh readiness. + continue } if err = t.Block(inCh); err != nil { break } } - if outFile.Readiness(EventMaskWrite) == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, EventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + // We might be ready to write now. Try again before + // blocking. 
+ continue + } + if err = t.Block(outCh); err != nil { + break } } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 035e2a6b0..9ce4f280a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { // waitForOut waits for dw.outfile to be read. func (dw *dualWaiter) waitForOut(t *kernel.Task) error { - if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if dw.outCh == nil { - dw.outW, dw.outCh = waiter.NewChannelEntry(nil) - dw.outFile.EventRegister(&dw.outW, eventMaskWrite) - // We might be ready now. Try again before blocking. - return nil - } - if err := t.Block(dw.outCh); err != nil { - return err - } - } - return nil + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. See b/172075629, b/170743336. + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready to write now. Try again before blocking. + return nil + } + return t.Block(dw.outCh) } // destroy cleans up resources help by dw. No more calls to wait* can occur diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index c855608db..440c9307c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "file_description_refs.go", package = "vfs", prefix = "FileDescription", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FileDescription", }, @@ -43,7 +43,7 @@ go_template_instance( out = "mount_namespace_refs.go", package = "vfs", prefix = "MountNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "MountNamespace", }, @@ -54,7 +54,7 @@ go_template_instance( out = "filesystem_refs.go", package = "vfs", prefix = "Filesystem", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Filesystem", }, @@ -87,6 +87,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "save_restore.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -99,6 +100,7 @@ go_library( "//pkg/gohacks", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8f36c3e3b..a98aac52b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -74,7 +74,7 @@ type epollInterestKey struct { // +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. - epoll *EpollInstance + epoll *EpollInstance `state:"wait"` // key is the file to which this epollInterest applies. key is immutable. key epollInterestKey diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 183957ad8..546e445aa 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. 
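The splice changes above (both the VFS1 and VFS2 copies) stop consulting writable readiness because pipes and eventfds can report themselves writable and still fail a particular write with EWOULDBLOCK; the writer instead registers a waiter entry once, retries immediately, and only blocks on later passes. The shape of that pattern reduced to its essentials; doWrite and the surrounding helper are stand-ins, while the waiter and task calls follow the APIs used in the diff.

package spliceexample

import (
	"gvisor.dev/gvisor/pkg/sentry/kernel"
	"gvisor.dev/gvisor/pkg/sentry/vfs"
	"gvisor.dev/gvisor/pkg/syserror"
	"gvisor.dev/gvisor/pkg/waiter"
)

// writeWithRetry shows the register-once, retry-then-block loop used above.
func writeWithRetry(t *kernel.Task, outFile *vfs.FileDescription, doWrite func() error) error {
	var (
		outCh chan struct{}
		outW  waiter.Entry
	)
	for {
		err := doWrite()
		if err != syserror.ErrWouldBlock {
			return err // success or a real error
		}
		if outCh == nil {
			outW, outCh = waiter.NewChannelEntry(nil)
			outFile.EventRegister(&outW, waiter.EventOut)
			defer outFile.EventUnregister(&outW)
			// We might be ready to write now; try again before blocking.
			continue
		}
		if err := t.Block(outCh); err != nil {
			return err
		}
	}
}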
if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 2d27d9d35..ba6e6ed49 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { return vfs.PrependPathAtVFSRootError{} } - if &d.vfsd == mnt.Root() { + if mnt != nil && &d.vfsd == mnt.Root() { return nil } if d.parent == nil { @@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath d = d.parent } } + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func DebugPathname(d *Dentry) string { + var b fspath.Builder + _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 3f0b8f45b..107171b61 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -65,7 +65,7 @@ type Inotify struct { // queue is used to notify interested parties when the inotify instance // becomes readable or writable. - queue waiter.Queue `state:"nosave"` + queue waiter.Queue // evMu *only* protects the events list. We need a separate lock while // queuing events: using mu may violate lock ordering, since at that point diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 55783d4eb..1ff202f2a 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. package vfs import ( diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 78f115bfa..3ea981ad4 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,6 +107,7 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount if opts.ReadOnly { mnt.setReadOnlyLocked(true) } + refsvfs2.Register(mnt) return mnt } @@ -470,11 +472,12 @@ func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { // tryIncMountedRef does not require that a reference is held on mnt. func (mnt *Mount) tryIncMountedRef() bool { for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + r := atomic.LoadInt64(&mnt.refs) + if r <= 0 { // r < 0 => MSB set => eagerly unmounted return false } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) { + refsvfs2.LogTryIncRef(mnt, r+1) return true } } @@ -484,29 +487,53 @@ func (mnt *Mount) tryIncMountedRef() bool { func (mnt *Mount) IncRef() { // In general, negative values for mnt.refs are valid because the MSB is // the eager-unmount bit. 
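// A standalone sketch (illustrative, not the gVisor implementation) of the
// refcount encoding that tryIncMountedRef and Mount.DecRef above rely on: the
// most-significant bit of the int64 is repurposed as an "eagerly unmounted"
// flag, so a negative value still carries a valid count in its low 63 bits,
// and r &^ math.MinInt64 masks the flag off when checking for zero.
package main

import (
	"fmt"
	"math"
	"sync/atomic"
)

type mount struct {
	refs int64 // Low 63 bits: count. MSB: eager-unmount flag.
}

func (m *mount) incRef() {
	atomic.AddInt64(&m.refs, 1)
}

// decRef reports whether the masked count dropped to zero.
func (m *mount) decRef() bool {
	r := atomic.AddInt64(&m.refs, -1)
	return r&^math.MinInt64 == 0 // Mask out the MSB flag before comparing.
}

// setEagerUnmounted sets the MSB without disturbing the count.
func (m *mount) setEagerUnmounted() {
	for {
		r := atomic.LoadInt64(&m.refs)
		if atomic.CompareAndSwapInt64(&m.refs, r, r|math.MinInt64) {
			return
		}
	}
}

// tryIncRef succeeds only while the mount is alive and not eagerly unmounted,
// mirroring tryIncMountedRef: any value <= 0 means the MSB is set or the
// count is already gone.
func (m *mount) tryIncRef() bool {
	for {
		r := atomic.LoadInt64(&m.refs)
		if r <= 0 {
			return false
		}
		if atomic.CompareAndSwapInt64(&m.refs, r, r+1) {
			return true
		}
	}
}

func main() {
	m := &mount{refs: 1}
	fmt.Println(m.tryIncRef()) // true: count is now 2.
	m.setEagerUnmounted()
	fmt.Println(m.tryIncRef()) // false: MSB set, refs is negative.
	fmt.Println(m.decRef())    // false: masked count is 1.
	fmt.Println(m.decRef())    // true: masked count reached zero.
}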
- atomic.AddInt64(&mnt.refs, 1) + r := atomic.AddInt64(&mnt.refs, 1) + refsvfs2.LogIncRef(mnt, r) } // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - var vd VirtualDentry - if mnt.parent() != nil { - mnt.vfs.mountMu.Lock() - mnt.vfs.mounts.seq.BeginWrite() - vd = mnt.vfs.disconnectLocked(mnt) - mnt.vfs.mounts.seq.EndWrite() - mnt.vfs.mountMu.Unlock() - } - if mnt.root != nil { - mnt.root.DecRef(ctx) - } - mnt.fs.DecRef(ctx) - if vd.Ok() { - vd.DecRef(ctx) - } + r := atomic.AddInt64(&mnt.refs, -1) + if r&^math.MinInt64 == 0 { // mask out MSB + refsvfs2.Unregister(mnt) + mnt.destroy(ctx) + } +} + +func (mnt *Mount) destroy(ctx context.Context) { + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + if mnt.root != nil { + mnt.root.DecRef(ctx) } + mnt.fs.DecRef(ctx) + if vd.Ok() { + vd.DecRef(ctx) + } +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (mnt *Mount) RefType() string { + return "vfs.Mount" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (mnt *Mount) LeakMessage() string { + return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (mnt *Mount) LogRefs() bool { + return false } // DecRef decrements mntns' reference count. diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index cb8c56bd3..cb882a983 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) { parent := &Mount{} point := &Dentry{} if m := mt.Lookup(parent, point); m != nil { - t.Errorf("empty mountTable lookup: got %p, wanted nil", m) + t.Errorf("Empty mountTable lookup: got %p, wanted nil", m) } } @@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { k := 
keys[i&(numMounts-1)] mi, ok := ms.Load(k) if !ok { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } m := mi.(*Mount) if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m, _ := ms.Load(k) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b7d122d22..cb48c37a1 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -98,7 +98,6 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 - // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -212,6 +211,26 @@ loop: } } +// Range calls f on each Mount in mt. If f returns false, Range stops iteration +// and returns immediately. +func (mt *mountTable) Range(f func(*Mount) bool) { + tcap := uintptr(1) << (mt.size & mtSizeOrderMask) + slotPtr := mt.slots + last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes)) + for { + slot := (*mountSlot)(slotPtr) + if slot.value != nil { + if !f((*Mount)(slot.value)) { + return + } + } + if slotPtr == last { + return + } + slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes) + } +} + // Insert inserts the given mount into mt. // // Preconditions: mt must not already contain a Mount with the same mount point diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go new file mode 100644 index 000000000..7723ed643 --- /dev/null +++ b/pkg/sentry/vfs/save_restore.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// FilesystemImplSaveRestoreExtension is an optional extension to +// FilesystemImpl. +type FilesystemImplSaveRestoreExtension interface { + // PrepareSave prepares this filesystem for serialization. 
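// The Fatalf -> Errorf changes in the benchmarks above follow the rule from
// package testing that t.Fatal/t.Fatalf (which call runtime.Goexit via
// FailNow) may only be called from the goroutine running the Test or
// Benchmark function. A minimal sketch of the safe pattern used above, with
// illustrative names, placed in a _test.go file: record the failure with
// Errorf and return from the worker goroutine instead.
package demo

import (
	"sync"
	"testing"
)

func TestParallelDouble(t *testing.T) {
	double := func(x int) int { return 2 * x }
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if got, want := double(i), 2*i; got != want {
				// Not t.Fatalf: FailNow must only be called from the
				// goroutine running the test function. Mark the failure
				// and return from this goroutine instead.
				t.Errorf("double(%d) = %d, want %d", i, got, want)
				return
			}
		}(i)
	}
	wg.Wait()
}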
+ PrepareSave(ctx context.Context) error + + // CompleteRestore completes restoration from checkpoint for this + // filesystem after deserialization. + CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error +} + +// PrepareSave prepares all filesystems for serialization. +func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.PrepareSave(ctx); err != nil { + ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) + } + return nil +} + +// CompleteRestore completes restoration from checkpoint for all filesystems +// after deserialization. +func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.CompleteRestore(ctx, *opts); err != nil { + ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) + } + return nil +} + +// CompleteRestoreOptions contains options to +// VirtualFilesystem.CompleteRestore() and +// FilesystemImplSaveRestoreExtension.CompleteRestore(). +type CompleteRestoreOptions struct { + // If ValidateFileSizes is true, filesystem implementations backed by + // remote filesystems should verify that file sizes have not changed + // between checkpoint and restore. + ValidateFileSizes bool + + // If ValidateFileModificationTimestamps is true, filesystem + // implementations backed by remote filesystems should validate that file + // mtimes have not changed between checkpoint and restore. + ValidateFileModificationTimestamps bool +} + +// saveMounts is called by stateify. +func (vfs *VirtualFilesystem) saveMounts() []*Mount { + if atomic.LoadPointer(&vfs.mounts.slots) == nil { + // vfs.Init() was never called. + return nil + } + var mounts []*Mount + vfs.mounts.Range(func(mount *Mount) bool { + mounts = append(mounts, mount) + return true + }) + return mounts +} + +// loadMounts is called by stateify. +func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { + if mounts == nil { + return + } + vfs.mounts.Init() + for _, mount := range mounts { + vfs.mounts.Insert(mount) + } +} + +func (mnt *Mount) afterLoad() { + if atomic.LoadInt64(&mnt.refs) != 0 { + refsvfs2.Register(mnt) + } +} + +// afterLoad is called by stateify. +func (epi *epollInterest) afterLoad() { + // Mark all epollInterests as ready after restore so that the next call to + // EpollInstance.ReadEvents() rechecks their readiness. + epi.Callback(nil) +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 38d2701d2..48d6252f7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -71,7 +71,7 @@ type VirtualFilesystem struct { // points. // // mounts is analogous to Linux's mount_hashtable. - mounts mountTable + mounts mountTable `state:".([]*Mount)"` // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. @@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre // SyncAllFilesystems has the semantics of Linux's sync(2). 
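// A standalone sketch (hypothetical names) of the optional-extension pattern
// that PrepareSave/CompleteRestore above use: implementations opt in to
// checkpoint support by implementing a second, optional interface, and the
// caller discovers support at runtime with a type assertion, skipping
// implementations that do not provide it.
package main

import "fmt"

// Filesystem is the mandatory interface.
type Filesystem interface {
	Name() string
}

// Checkpointable is the optional extension.
type Checkpointable interface {
	PrepareSave() error
}

type tmpfs struct{}

func (tmpfs) Name() string { return "tmpfs" }

type gofs struct{}

func (gofs) Name() string       { return "gofs" }
func (gofs) PrepareSave() error { return nil }

func prepareAll(fss []Filesystem) error {
	failures := 0
	for _, fs := range fss {
		ext, ok := fs.(Checkpointable)
		if !ok {
			continue // No extension: nothing to prepare.
		}
		if err := ext.PrepareSave(); err != nil {
			fmt.Printf("%s: PrepareSave failed: %v\n", fs.Name(), err)
			failures++
		}
	}
	if failures != 0 {
		return fmt.Errorf("%d filesystems failed to prepare", failures)
	}
	return nil
}

func main() {
	fmt.Println(prepareAll([]Filesystem{tmpfs{}, gofs{}}))
}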
func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + var retErr error + for fs := range vfs.getFilesystems() { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef(ctx) + } + return retErr +} + +func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} { fss := make(map[*Filesystem]struct{}) vfs.filesystemsMu.Lock() + defer vfs.filesystemsMu.Unlock() for fs := range vfs.filesystems { if !fs.TryIncRef() { continue } fss[fs] = struct{}{} } - vfs.filesystemsMu.Unlock() - var retErr error - for fs := range fss { - if err := fs.impl.Sync(ctx); err != nil && retErr == nil { - retErr = err - } - fs.DecRef(ctx) - } - return retErr + return fss } // MkdirAllAt recursively creates non-existent directories on the given path diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD index f08599ebd..cb0001852 100644 --- a/pkg/shim/runsc/BUILD +++ b/pkg/shim/runsc/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "@com_github_containerd_containerd//log:go_default_library", "@com_github_containerd_go_runc//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index c5cf68efa..e7c9640ba 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -28,10 +28,12 @@ import ( "syscall" "time" + "github.com/containerd/containerd/log" runc "github.com/containerd/go-runc" specs "github.com/opencontainers/runtime-spec/specs-go" ) +// Monitor is the default process monitor to be used by runsc. var Monitor runc.ProcessMonitor = runc.Monitor // DefaultCommand is the default command for Runsc. @@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro return &c, nil } +// CreateOpts is a set of options to Runsc.Create(). type CreateOpts struct { runc.IO ConsoleSocket runc.ConsoleSocket @@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) { return res.ExitStatus, nil } +// ExecOpts is a set of options to runsc.Exec(). type ExecOpts struct { runc.IO PidFile string @@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts return Monitor.Wait(cmd, ec) } +// DeleteOpts is a set of options to runsc.Delete(). 
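// A standalone sketch (toy types) of the snapshot-under-lock pattern that
// getFilesystems above factors out of SyncAllFilesystems: hold the registry
// mutex only long enough to copy out the current members, then do the slow
// per-member work with the lock released.
package main

import (
	"fmt"
	"sync"
)

type resource struct {
	name string
}

type registry struct {
	mu      sync.Mutex
	members map[*resource]struct{}
}

// snapshot copies the current members so the caller can work on them without
// holding mu. The real code also pins each member with TryIncRef here and
// has the caller DecRef afterwards.
func (r *registry) snapshot() []*resource {
	r.mu.Lock()
	defer r.mu.Unlock()
	out := make([]*resource, 0, len(r.members))
	for m := range r.members {
		out = append(out, m)
	}
	return out
}

func (r *registry) syncAll() {
	for _, m := range r.snapshot() {
		fmt.Println("syncing", m.name) // Slow work, done without holding mu.
	}
}

func main() {
	r := &registry{members: make(map[*resource]struct{})}
	for _, name := range []string{"tmpfs", "gofer"} {
		r.members[&resource{name: name}] = struct{}{}
	}
	r.syncAll()
}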
type DeleteOpts struct { Force bool } @@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { if err := json.NewDecoder(rd).Decode(&e); err != nil { return nil, err } + log.L.Debugf("Stats returned: %+v", e.Stats) + if e.Type != "stats" { + return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type) + } + if e.Stats == nil { + return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`) + } return e.Stats, nil } diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 089b3bbef..92c51879b 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "pending_list", - out = "pending_list.go", - package = "state", - prefix = "pending", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "pendingMapper", - "Linker": "*pendingEntry", - }, -) - -go_template_instance( name = "deferred_list", out = "deferred_list.go", package = "state", @@ -83,7 +70,6 @@ go_library( "deferred_list.go", "encode.go", "encode_unsafe.go", - "pending_list.go", "state.go", "state_norace.go", "state_race.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 89467ca8e..e519ddeca 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -21,6 +21,7 @@ import ( "math" "reflect" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be -// directly assignable. See unsafePointerTo. +// directly assignable. See decode_unsafe.go. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { @@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) - if v.IsValid() { - obj.Set(unsafePointerTo(v)) - } + obj.Set(reflectValueRWAddr(v)) case wire.Bool: obj.SetBool(bool(x)) case wire.Int: @@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) - obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: @@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) { ds.pending.PushBack(rootOds) // Read the number of objects. - lastID, object, err := ReadHeader(ds.r) + numObjects, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } @@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) { var ( encoded wire.Object ods *objectDecodeState - id = objectID(1) + id objectID tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // - // Note that the structure of this decoding loop should match - // the raw decoding loop in printer.go. - for id <= objectID(lastID) { - // Unmarshal the object. 
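// A reduced sketch (local, hypothetical types) of the defensive decoding that
// the Stats change above adds around the JSON event emitted by runsc events:
// decode the event, then validate both the event type and the presence of the
// payload before using it, instead of silently returning a nil Stats on an
// unexpected event.
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"strings"
)

type stats struct {
	Memory uint64 `json:"memory"`
}

type event struct {
	Type  string `json:"type"`
	Stats *stats `json:"data"`
}

func decodeStats(r io.Reader) (*stats, error) {
	var e event
	if err := json.NewDecoder(r).Decode(&e); err != nil {
		return nil, err
	}
	if e.Type != "stats" {
		return nil, fmt.Errorf("unexpected event type %q, wanted %q", e.Type, "stats")
	}
	if e.Stats == nil {
		return nil, fmt.Errorf("event decoded but carried no stats payload")
	}
	return e.Stats, nil
}

func main() {
	good := `{"type":"stats","data":{"memory":42}}`
	bad := `{"type":"oom","data":null}`
	fmt.Println(decodeStats(strings.NewReader(good)))
	fmt.Println(decodeStats(strings.NewReader(bad)))
}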
+ // Note that the structure of this decoding loop should match the raw + // decoding loop in state/pretty/pretty.printer.printStream(). + for i := uint64(0); i < numObjects; { + // Unmarshal either a type object or object ID. encoded = wire.Load(ds.r) - - // Is this a type object? Handle inline. - if wt, ok := encoded.(*wire.Type); ok { - ds.types.Register(wt) + switch we := encoded.(type) { + case *wire.Type: + ds.types.Register(we) tid++ encoded = nil continue + case wire.Uint: + id = objectID(we) + i++ + // Unmarshal and resolve the actual object. + encoded = wire.Load(ds.r) + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } + // For error handling. + ods = nil + encoded = nil + default: + Failf("wanted type or object ID, got %#v", encoded) } - - // Actually resolve the object. - ods = ds.lookup(id) - if ods != nil { - // Decode the object. - ds.decodeObject(ods, ods.obj, encoded) - } else { - // If an object hasn't had interest registered - // previously or isn't yet valid, we deferred - // decoding until interest is registered. - ds.deferred[id] = encoded - } - - // For error handling. - ods = nil - encoded = nil - id++ } }); err != nil { // Include as much information as we can, taking into account @@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) { if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { - Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) + Failf("error decoding from %#v: %w", encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. + numDeferred := 0 for id, encoded := range ds.deferred { - // Shoud never happen, the graph was bogus. - Failf("still have deferred objects: one is ID %d, %#v", id, encoded) + numDeferred++ + if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 { + typ := ds.types.LookupType(typeID(s.TypeID)) + log.Warningf("unused deferred object: ID %d, type %v", id, typ) + } else { + log.Warningf("unused deferred object: ID %d, %#v", id, encoded) + } + } + if numDeferred != 0 { + Failf("still had %d deferred objects", numDeferred) } // Scan and fire all callbacks. We iterate over the list of incomplete diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go index d048f61a1..f1208e2a2 100644 --- a/pkg/state/decode_unsafe.go +++ b/pkg/state/decode_unsafe.go @@ -15,13 +15,62 @@ package state import ( + "fmt" "reflect" + "runtime" "unsafe" ) -// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on -// values representing unexported fields. This bypasses visibility, but not -// type safety. -func unsafePointerTo(obj reflect.Value) reflect.Value { +// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned +// reflect.Value is usable in assignments even if obj was obtained by the use +// of unexported struct fields. +// +// Preconditions: obj.CanAddr(). +func reflectValueRWAddr(obj reflect.Value) reflect.Value { return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) } + +// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the +// returned reflect.Value is usable in assignments even if obj was obtained by +// the use of unexported struct fields. 
+// +// Preconditions: +// * arr.Kind() == reflect.Array. +// * i, j, k >= 0. +// * i <= j <= k <= arr.Len(). +func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value { + if arr.Kind() != reflect.Array { + panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array)) + } + if i < 0 || j < 0 || k < 0 { + panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k)) + } + if i > j { + panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j)) + } + if j > k { + panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k)) + } + if k > arr.Len() { + panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len())) + } + + sliceTyp := reflect.SliceOf(arr.Type().Elem()) + if i == arr.Len() { + // By precondition, i == j == k == arr.Len(). + return reflect.MakeSlice(sliceTyp, 0, 0) + } + slh := reflect.SliceHeader{ + // reflect.Value.CanAddr() == false for arrays, so we need to get the + // address from the first element of the array. + Data: arr.Index(i).UnsafeAddr(), + Len: j - i, + Cap: k - i, + } + slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem() + // Before slobj is constructed, arr holds the only pointer-typed pointer to + // the array since reflect.SliceHeader.Data is a uintptr, so arr must be + // kept alive. + runtime.KeepAlive(arr) + return slobj +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 92fcad4e9..560e7c2a3 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -17,13 +17,14 @@ package state import ( "context" "reflect" + "sort" "gvisor.dev/gvisor/pkg/state/wire" ) // objectEncodeState the type and identity of an object occupying a memory // address range. This is the value type for addrSet, and the intrusive entry -// for the pending and deferred lists. +// for the deferred list. type objectEncodeState struct { // id is the assigned ID for this object. id objectID @@ -47,7 +48,6 @@ type objectEncodeState struct { // references may be updated directly and automatically. refs []*wire.Ref - pendingEntry deferredEntry } @@ -93,9 +93,15 @@ type encodeState struct { // serialized. pendingTypes []wire.Type - // pending is the list of objects to be serialized. Serialization does + // pending maps object IDs to objects to be serialized. Serialization does // not actually occur until the full object graph is computed. - pending pendingList + pending map[objectID]*objectEncodeState + + // encodedStructs maps reflect.Values representing structs to previous + // encodings of those structs. This is necessary to avoid duplicate calls + // to SaverLoader.StateSave() that may result in multiple calls to + // Sink.SaveValue() for a given field, resulting in object duplication. + encodedStructs map[reflect.Value]*wire.Struct // stats tracks time data. stats Stats @@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { // depending on this value knows there's nothing there. return } - if seg, _ := es.values.Find(addr); seg.Ok() { + seg, gap := es.values.Find(addr) + if seg.Ok() { // Ensure the map types match. existing := seg.Value() if existing.obj.Type() != obj.Type() { @@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { } // Record the map. + r := addrRange{addr, addr + 1} oes := &objectEncodeState{ id: es.nextID(), obj: obj, how: encodeMapAsValue, } - es.values.Add(addrRange{addr, addr + 1}, oes) - es.pending.PushBack(oes) + // Use Insert instead of InsertWithoutMergingUnchecked when race + // detection is enabled to get additional sanity-checking from Merge. 
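// A small standalone demonstration of the reflect.NewAt trick that
// reflectValueRWAddr above is built on: a reflect.Value reached through an
// unexported field reports CanSet() == false, but a new Value constructed
// from the field's address with reflect.NewAt is assignable, bypassing
// visibility (though not type safety). The box/hidden names are illustrative.
package main

import (
	"fmt"
	"reflect"
	"unsafe"
)

type box struct {
	hidden int64 // Unexported: not settable through ordinary reflection.
}

func main() {
	b := &box{hidden: 1}
	f := reflect.ValueOf(b).Elem().Field(0)
	fmt.Println(f.CanSet()) // false: obtained via an unexported field.

	// Rebuild an assignable Value at the same address and type.
	rw := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem()
	rw.SetInt(42)

	fmt.Println(b.hidden) // 42
}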
+ if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes es.deferred.PushBack(oes) // See above: no ref recording. @@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { obj: obj, } es.zeroValues[typ] = oes - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) } @@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { size = 1 // See above. } - // Calculate the container. end := addr + size r := addrRange{addr, end} - if seg, _ := es.values.Find(addr); seg.Ok() { + seg := es.values.LowerBoundSegment(addr) + var ( + oes *objectEncodeState + gap addrGapIterator + ) + + // Does at least one previously-registered object overlap this one? + if seg.Ok() && seg.Start() < end { existing := seg.Value() - switch { - case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): - // The object is a perfect match. Happy path. Avoid the - // traversal and just return directly. We don't need to - // encode the type information or any dots here. + + if seg.Range() == r && typ == existing.obj.Type() { + // This exact object is already registered. Avoid the traversal and + // just return directly. We don't need to encode the type + // information or any dots here. ref.Root = wire.Uint(existing.id) existing.refs = append(existing.refs, ref) return + } - case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): - // The previously registered object is larger than - // this, no need to update. But we expect some - // traversal below. + if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) { + // This object is contained within a previously-registered object. + // Perform traversal from the container to the new object. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr) + ref.Type = es.findType(existing.obj.Type()) + existing.refs = append(existing.refs, ref) + return + } - case seg.Start() == addr && seg.End() == end: - if !isSameSizeParent(obj, existing.obj.Type()) { - break // Needs traversal. + // This object contains one or more previously-registered objects. + // Remove them and update existing references to use the new one. + oes := &objectEncodeState{ + // Reuse the root ID of the first contained element. + id: existing.id, + obj: obj, + } + type elementEncodeState struct { + addr uintptr + typ reflect.Type + refs []*wire.Ref + } + var ( + elems []elementEncodeState + gap addrGapIterator + ) + for { + // Each contained object should be completely contained within + // this one. + if raceEnabled && !r.IsSupersetOf(seg.Range()) { + Failf("containing object %#v does not contain existing object %#v", obj, existing.obj) } - fallthrough // Needs update. - - case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): - // Update the object and redo the encoding. - old := existing.obj - existing.obj = obj + elems = append(elems, elementEncodeState{ + addr: seg.Start(), + typ: existing.obj.Type(), + refs: existing.refs, + }) + delete(es.pending, existing.id) es.deferred.Remove(existing) - es.deferred.PushBack(existing) - - // The previously registered object is superseded by - // this new object. We are guaranteed to not have any - // mergeable neighbours in this segment set. - if !raceEnabled { - seg.SetRangeUnchecked(r) - } else { - // Add extra paranoid. 
This will be statically - // removed at compile time unless a race build. - es.values.Remove(seg) - es.values.Add(r, existing) - seg = es.values.LowerBoundSegment(addr) + gap = es.values.Remove(seg) + seg = gap.NextSegment() + if !seg.Ok() || seg.Start() >= end { + break } - - // Compute the traversal required & update references. - dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) - wt := es.findType(obj.Type()) - for _, ref := range existing.refs { + existing = seg.Value() + } + wt := es.findType(typ) + for _, elem := range elems { + dots := traverse(typ, elem.typ, addr, elem.addr) + for _, ref := range elem.refs { + ref.Root = wire.Uint(oes.id) ref.Dots = append(ref.Dots, dots...) ref.Type = wt } - default: - // There is a non-sensical overlap. - Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + oes.refs = append(oes.refs, elem.refs...) } - - // Compute the new reference, record and return it. - ref.Root = wire.Uint(existing.id) - ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) - ref.Type = es.findType(obj.Type()) - existing.refs = append(existing.refs, ref) + // Finally register the new containing object. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) return } - // The only remaining case is a pointer value that doesn't overlap with - // any registered addresses. Create a new entry for it, and start - // tracking the first reference we just created. - oes := &objectEncodeState{ + // No existing object overlaps this one. Register a new object. + oes = &objectEncodeState{ id: es.nextID(), obj: obj, } + if seg.Ok() { + gap = seg.PrevGap() + } else { + gap = es.values.LastGap() + } if !raceEnabled { - es.values.AddWithoutMerging(r, oes) + es.values.InsertWithoutMergingUnchecked(gap, r, oes) } else { - // Merges should never happen. This is just enabled extra - // sanity checks because the Merge function below will panic. - es.values.Add(r, oes) + es.values.Insert(gap, r, oes) } - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) ref.Root = wire.Uint(oes.id) oes.refs = append(oes.refs, ref) @@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) { // encodeStruct encodes a composite object. func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + if s, ok := es.encodedStructs[obj]; ok { + *dest = s + return + } + s := &wire.Struct{} + *dest = s + es.encodedStructs[obj] = s + // Ensure that the obj is addressable. There are two cases when it is // not. First, is when this is dispatched via SaveValue. Second, when // this is a map key as a struct. Either way, we need to make a copy to @@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { obj = localObj.Elem() } - // Prepare the value. - s := &wire.Struct{} - *dest = s - // Look the type up in the database. te, ok := es.types.Lookup(obj.Type()) if te == nil { @@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) { Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) } - // Check that items are pending. - if es.pending.Front() == nil { + // Check that we have objects to serialize. + if len(es.pending) == 0 { Failf("pending is empty?") } - // Write the header with the number of objects. 
Note that there is no - // way that es.lastID could conflict with objectID, which would - // indicate that an impossibly large encoding. - if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + // Write the header with the number of objects. + if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil { Failf("error writing header: %w", err) } // Serialize all pending types and pending objects. Note that we don't // bother removing from this list as we walk it because that just // wastes time. It will not change after this point. - var id objectID if err := safely(func() { for _, wt := range es.pendingTypes { // Encode the type. wire.Save(es.w, &wt) } - for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { - id++ // First object is 1. - if oes.id != id { - Failf("expected id %d, got %d", id, oes.id) - } - - // Marshall the object. + // Emit objects in ID order. + ids := make([]objectID, 0, len(es.pending)) + for id := range es.pending { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + for _, id := range ids { + // Encode the id. + wire.Save(es.w, wire.Uint(id)) + // Marshal the object. + oes := es.pending[id] wire.Save(es.w, oes.encoded) } }); err != nil { // Include the object and the error. Failf("error serializing object %#v: %w", oes.encoded, err) } - - // Check what we wrote. - if id != es.lastID { - Failf("expected %d objects, wrote %d", es.lastID, id) - } } // objectFlag indicates that the length is a # of objects, rather than a raw @@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error { }) } -// pendingMapper is for the pending list. -type pendingMapper struct{} - -func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // deferredMapper is for the deferred list. type deferredMapper struct{} diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index 887f453a9..c6e8bb31d 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { buf.WriteString(typ) buf.WriteString(")(") buf.WriteString(baseRef) + buf.WriteString(")") for _, component := range x.Dots { switch v := component.(type) { case *wire.FieldName: @@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) } } - buf.WriteString(")") fullRef = buf.String() } if p.html { @@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // Note that this loop must match the general structure of the // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. + type objectAndID struct { + id uint64 + obj wire.Object + } var ( tid uint64 = 1 - objects []wire.Object + objects []objectAndID ) - for oid := uint64(1); oid <= length; { - // Unmarshal the object. + for i := uint64(0); i < length; { + // Unmarshal either a type object or object ID. encoded := wire.Load(r) - - // Is this a type? - if typ, ok := encoded.(*wire.Type); ok { + switch we := encoded.(type) { + case *wire.Type: str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - p.typeSpecs[tag] = typ + p.typeSpecs[tag] = we if p.html { // See below. 
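// A reduced sketch (toy stream, not the state/wire format) of the tagged
// stream that the reworked encoder and decoder above agree on: type
// definitions may appear interleaved in the stream, and every object is
// preceded by its explicit object ID, so the reader switches on what it just
// pulled instead of assuming IDs form the dense sequence 1..N.
package main

import "fmt"

type typeDef struct{ name string }

type objectID uint64

func readStream(stream []interface{}, numObjects int) (map[objectID]interface{}, error) {
	var types []typeDef
	objects := make(map[objectID]interface{})
	i := 0
	for decoded := 0; decoded < numObjects; {
		if i >= len(stream) {
			return nil, fmt.Errorf("stream ended after %d of %d objects", decoded, numObjects)
		}
		item := stream[i]
		i++
		switch v := item.(type) {
		case typeDef:
			types = append(types, v) // Register the type; not an object.
		case objectID:
			if i >= len(stream) {
				return nil, fmt.Errorf("object ID %d has no value", v)
			}
			objects[v] = stream[i] // The ID tags the value that follows it.
			i++
			decoded++
		default:
			return nil, fmt.Errorf("wanted type or object ID, got %#v", item)
		}
	}
	return objects, nil
}

func main() {
	stream := []interface{}{
		typeDef{"vfs.Mount"},
		objectID(3), "mount state",
		typeDef{"kernel.Task"},
		objectID(7), "task state",
	}
	fmt.Println(readStream(stream, 2))
}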
tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) @@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { return err } tid++ - continue + case wire.Uint: + // Unmarshal the actual object. + objects = append(objects, objectAndID{ + id: uint64(we), + obj: wire.Load(r), + }) + i++ + default: + return fmt.Errorf("wanted type or object ID, got %#v", encoded) } - - // Otherwise, it is a node. - objects = append(objects, encoded) - oid++ } - for i, encoded := range objects { - // oid starts at 1. - oid := i + 1 + for _, objAndID := range objects { // Format the node. - str, _ := p.format(graph, 0, encoded) - tag := fmt.Sprintf("g%dr%d", graph, oid) + str, _ := p.format(graph, 0, objAndID.obj) + tag := fmt.Sprintf("g%dr%d", graph, objAndID.id) if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) diff --git a/pkg/state/state.go b/pkg/state/state.go index acb629969..6b8540f03 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error { func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. es := encodeState{ - ctx: ctx, - w: w, - types: makeTypeEncodeDatabase(), - zeroValues: make(map[reflect.Type]*objectEncodeState), + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), + pending: make(map[objectID]*objectEncodeState), + encodedStructs: make(map[reflect.Value]*wire.Struct), } // Perform the encoding. diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go index bd2c2b399..69143d194 100644 --- a/pkg/state/tests/struct.go +++ b/pkg/state/tests/struct.go @@ -54,12 +54,47 @@ type outerArray struct { } // +stateify savable +type outerSlice struct { + inner []inner +} + +// +stateify savable type inner struct { v int64 } // +stateify savable +type outerFieldValue struct { + inner innerFieldValue +} + +// +stateify savable +type innerFieldValue struct { + v int64 `state:".(*savedFieldValue)"` +} + +// +stateify savable +type savedFieldValue struct { + v int64 +} + +func (ifv *innerFieldValue) saveV() *savedFieldValue { + return &savedFieldValue{ifv.v} +} + +func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { + ifv.v = sfv.v +} + +// +stateify savable type system struct { v1 interface{} v2 interface{} } + +// +stateify savable +type system3 struct { + v1 interface{} + v2 interface{} + v3 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index de9d17aa7..c91c2c032 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -15,6 +15,7 @@ package tests import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/state" @@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) { } func TestEmbeddedPointers(t *testing.T) { - var ( - ofs outerSame - of1 outerFieldFirst - of2 outerFieldSecond - oa outerArray - ) + // Give each int64 a random value to prevent Go from using + // runtime.staticuint64s, which confounds tests for struct duplication. 
+ magic := func() int64 { + for { + n := rand.Int63() + if n < 0 || n > 255 { + return n + } + } + } + + ofs := outerSame{inner{magic()}} + of1 := outerFieldFirst{inner{magic()}, magic()} + of2 := outerFieldSecond{magic(), inner{magic()}} + oa := outerArray{[2]inner{{magic()}, {magic()}}} + osl := outerSlice{oa.inner[:]} + ofv := outerFieldValue{innerFieldValue{magic()}} runTestCases(t, false, "embedded-pointers", []interface{}{ system{&ofs, &ofs.inner}, @@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) { system{&oa, &oa.inner[1]}, system{&oa.inner[0], &oa}, system{&oa.inner[1], &oa}, + system3{&oa, &oa.inner[0], &oa.inner[1]}, + system3{&oa, &oa.inner[1], &oa.inner[0]}, + system3{&oa.inner[0], &oa, &oa.inner[1]}, + system3{&oa.inner[1], &oa, &oa.inner[0]}, + system3{&oa.inner[0], &oa.inner[1], &oa}, + system3{&oa.inner[1], &oa.inner[0], &oa}, + system{&oa, &osl}, + system{&osl, &oa}, + system{&ofv, &ofv.inner}, + system{&ofv.inner, &ofv}, }) } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 12b061def..b196324c7 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -97,6 +97,9 @@ type testConnection struct { func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, err + } entry, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&entry, waiter.EventOut) @@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } // Give c.Read() a chance to block before closing the connection. @@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } c.SetDeadline(time.Now().Add(time.Minute)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 6f81b0164..530f2ae2f 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -205,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker { if !ok { t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) } - options := ip.Options() + options := []byte(ip.Options()) // cmp.Diff does not consider nil slices equal to empty slices, but we do. if len(want) == 0 && len(options) == 0 { return @@ -859,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker { } } +// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. +func ICMPv4Pointer(want uint8) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Pointer(); got != want { + t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) + } + } +} + // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. // This assumes that the payload exactly makes up the rest of the slice. 
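// A reduced sketch (toy header type, _test.go file) of the closure-based
// checker pattern that ICMPv4Pointer and ICMPv6Payload above extend: each
// checker captures its expected value and receives (t, header), so a test can
// compose an arbitrary list of field checks over one parsed packet.
package demo

import "testing"

type header struct {
	typ  uint8
	code uint8
}

type checker func(*testing.T, header)

func hasType(want uint8) checker {
	return func(t *testing.T, h header) {
		t.Helper()
		if h.typ != want {
			t.Errorf("got type %d, want %d", h.typ, want)
		}
	}
}

func hasCode(want uint8) checker {
	return func(t *testing.T, h header) {
		t.Helper()
		if h.code != want {
			t.Errorf("got code %d, want %d", h.code, want)
		}
	}
}

func check(t *testing.T, h header, checkers ...checker) {
	t.Helper()
	for _, c := range checkers {
		c(t, h)
	}
}

func TestEchoReply(t *testing.T) {
	h := header{typ: 0, code: 0}
	check(t, h, hasType(0), hasCode(0))
}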
func ICMPv4Checksum() TransportChecker { @@ -953,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { } } +// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific +// field. +func ICMPv6TypeSpecific(want uint32) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + if got := icmpv6.TypeSpecific(); got != want { + t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) + } + } +} + +// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. +func ICMPv6Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + payload := icmpv6.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 504408878..2f13dea6a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -99,7 +99,8 @@ const ( // ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 ) // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. @@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 4c6e4be64..961b77628 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,7 +39,6 @@ import ( // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Options | Padding | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// const ( versIHL = 0 tos = 1 @@ -93,7 +93,7 @@ type IPv4Fields struct { DstAddr tcpip.Address } -// IPv4 represents an ipv4 header stored in a byte array. +// IPv4 is an IPv4 header. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other @@ -106,10 +106,13 @@ const ( IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. 
+ // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. // // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 @@ -130,7 +133,7 @@ const ( // IPv4ProtocolNumber is IPv4's network protocol number. IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 - // IPv4Version is the version of the ipv4 protocol. + // IPv4Version is the version of the IPv4 protocol. IPv4Version = 4 // IPv4AllSystems is the all systems IPv4 multicast address as per @@ -148,6 +151,13 @@ const ( // packet that every IPv4 capable host must be able to // process/reassemble. IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 ) // Flags that may be set in an IPv4 packet. @@ -191,14 +201,13 @@ func IPVersion(b []byte) int { // Internet Header Length is the length of the internet header in 32 // bit words, and thus points to the beginning of the data. Note that // the minimum value for a correct header is 5. -// const ( ipVersionShift = 4 ipIHLMask = 0x0f IPv4IHLStride = 4 ) -// HeaderLength returns the value of the "header length" field of the ipv4 +// HeaderLength returns the value of the "header length" field of the IPv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & ipIHLMask) * IPv4IHLStride @@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) { b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } -// ID returns the value of the identifier field of the ipv4 header. +// ID returns the value of the identifier field of the IPv4 header. func (b IPv4) ID() uint16 { return binary.BigEndian.Uint16(b[id:]) } -// Protocol returns the value of the protocol field of the ipv4 header. +// Protocol returns the value of the protocol field of the IPv4 header. func (b IPv4) Protocol() uint8 { return b[protocol] } -// Flags returns the "flags" field of the ipv4 header. +// Flags returns the "flags" field of the IPv4 header. func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } @@ -232,41 +241,44 @@ func (b IPv4) More() bool { return b.Flags()&IPv4FlagMoreFragments != 0 } -// TTL returns the "TTL" field of the ipv4 header. +// TTL returns the "TTL" field of the IPv4 header. func (b IPv4) TTL() uint8 { return b[ttl] } -// FragmentOffset returns the "fragment offset" field of the ipv4 header. +// FragmentOffset returns the "fragment offset" field of the IPv4 header. func (b IPv4) FragmentOffset() uint16 { return binary.BigEndian.Uint16(b[flagsFO:]) << 3 } -// TotalLength returns the "total length" field of the ipv4 header. +// TotalLength returns the "total length" field of the IPv4 header. func (b IPv4) TotalLength() uint16 { return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } -// Checksum returns the checksum field of the ipv4 header. +// Checksum returns the checksum field of the IPv4 header. 
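// A small worked sketch of the version/IHL packing that HeaderLength and
// SetHeaderLength above implement: the IHL field is 4 bits and counts 4-byte
// words, so the header cannot exceed 15*4 = 60 bytes and the options cannot
// exceed 60-20 = 40 bytes. A minimal 20-byte IPv4 header therefore starts
// with the familiar 0x45 byte.
package main

import "fmt"

const (
	ipVersionShift = 4
	ipIHLMask      = 0x0f
	ihlStride      = 4 // IHL is measured in 32-bit (4-byte) words.
)

func packVersIHL(version, headerLen uint8) uint8 {
	return version<<ipVersionShift | (headerLen/ihlStride)&ipIHLMask
}

func headerLen(versIHL uint8) uint8 {
	return (versIHL & ipIHLMask) * ihlStride
}

func main() {
	fmt.Printf("minimal header byte: %#x\n", packVersIHL(4, 20)) // 0x45
	fmt.Println("max header bytes:", headerLen(0x4f))            // 15*4 = 60
	fmt.Println("max options bytes:", headerLen(0x4f)-20)        // 40
}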
func (b IPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[checksum:]) } -// SourceAddress returns the "source address" field of the ipv4 header. +// SourceAddress returns the "source address" field of the IPv4 header. func (b IPv4) SourceAddress() tcpip.Address { return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// DestinationAddress returns the "destination address" field of the IPv4 // header. func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// Options returns a a buffer holding the options. -func (b IPv4) Options() []byte { +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - return b[options:hdrLen:hdrLen] + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } -// PayloadLength returns the length of the payload portion of the ipv4 packet. +// PayloadLength returns the length of the payload portion of the IPv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) } -// TOS returns the "type of service" field of the ipv4 header. +// TOS returns the "type of service" field of the IPv4 header. func (b IPv4) TOS() (uint8, uint32) { return b[tos], 0 } -// SetTOS sets the "type of service" field of the ipv4 header. +// SetTOS sets the "type of service" field of the IPv4 header. func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } @@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) { b[ttl] = v } -// SetTotalLength sets the "total length" field of the ipv4 header. +// SetTotalLength sets the "total length" field of the IPv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. +// SetChecksum sets the checksum field of the IPv4 header. func (b IPv4) SetChecksum(v uint16) { binary.BigEndian.PutUint16(b[checksum:], v) } // SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. +// IPv4 header. func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) @@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) { binary.BigEndian.PutUint16(b[id:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. +// SetSourceAddress sets the "source address" field of the IPv4 header. func (b IPv4) SetSourceAddress(addr tcpip.Address) { copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) } -// SetDestinationAddress sets the "destination address" field of the ipv4 +// SetDestinationAddress sets the "destination address" field of the IPv4 // header. func (b IPv4) SetDestinationAddress(addr tcpip.Address) { copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) } -// CalculateChecksum calculates the checksum of the ipv4 header. +// CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { return Checksum(b[:b.HeaderLength()], 0) } -// Encode encodes all the fields of the ipv4 header. +// Encode encodes all the fields of the IPv4 header. 
func (b IPv4) Encode(i *IPv4Fields) { b.SetHeaderLength(i.IHL) b[tos] = i.TOS @@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) { copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) } -// EncodePartial updates the total length and checksum fields of ipv4 header, +// EncodePartial updates the total length and checksum fields of IPv4 header, // taking in the partial checksum, which is the checksum of the header without // the total length and checksum fields. It is useful in cases when similar // packets are produced. @@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool { } return addr[0] == 0x7f } + +// ========================= Options ========================== + +// An IPv4OptionType can hold the valuse for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// Potential errors when parsing generic IP options. +var ( + ErrIPv4OptZeroLength = errors.New("zero length IP option") + ErrIPv4OptDuplicate = errors.New("duplicate IP option") + ErrIPv4OptInvalid = errors.New("invalid IP option") + ErrIPv4OptMalformed = errors.New("malformed IP option") + ErrIPv4OptionTruncated = errors.New("truncated IP option") + ErrIPv4OptionAddress = errors.New("bad IP option address") +) + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. 
It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return IPv4Options(i.newOptions[i.writePoint:]) +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// - An error which is non-nil if an error condition was encountered. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, true, ErrIPv4OptMalformed + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen == 0 { + i.ErrCursor++ + return nil, true, ErrIPv4OptZeroLength + } + + if optLen == 1 { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + + if optLen > uint8(len(i.options)) { + i.ErrCursor++ + return nil, true, ErrIPv4OptionTruncated + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. 
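+	// For example, a Timestamp option must be at least
+	// IPv4OptionTimestampHdrLength (4) bytes long (type, length, pointer and
+	// overflow/flags octets), and a Record Route option must be at least
+	// IPv4OptionRecordRouteHdrLength (3) bytes long (type, length and pointer
+	// octets).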
+	switch optType {
+	case IPv4OptionTimestampType:
+		if optLen < IPv4OptionTimestampHdrLength {
+			i.ErrCursor++
+			return nil, true, ErrIPv4OptMalformed
+		}
+		retval := IPv4OptionTimestamp(optionBody)
+		return &retval, false, nil
+
+	case IPv4OptionRecordRouteType:
+		if optLen < IPv4OptionRecordRouteHdrLength {
+			i.ErrCursor++
+			return nil, true, ErrIPv4OptMalformed
+		}
+		retval := IPv4OptionRecordRoute(optionBody)
+		return &retval, false, nil
+	}
+	retval := IPv4OptionGeneric(optionBody)
+	return &retval, false, nil
+}
+
+//
+// IP Timestamp option - RFC 791 page 22.
+// +--------+--------+--------+--------+
+// |01000100| length | pointer|oflw|flg|
+// +--------+--------+--------+--------+
+// |         internet address          |
+// +--------+--------+--------+--------+
+// |             timestamp             |
+// +--------+--------+--------+--------+
+// |                ...                |
+//
+// Type = 68
+//
+// The Option Length is the number of octets in the option counting
+// the type, length, pointer, and overflow/flag octets (maximum
+// length 40).
+//
+// The Pointer is the number of octets from the beginning of this
+// option to the end of timestamps plus one (i.e., it points to the
+// octet beginning the space for next timestamp).  The smallest
+// legal value is 5.  The timestamp area is full when the pointer
+// is greater than the length.
+//
+// The Overflow (oflw) [4 bits] is the number of IP modules that
+// cannot register timestamps due to lack of space.
+//
+// The Flag (flg) [4 bits] values are
+//
+//   0 -- time stamps only, stored in consecutive 32-bit words,
+//
+//   1 -- each timestamp is preceded with internet address of the
+//        registering entity,
+//
+//   3 -- the internet address fields are prespecified.  An IP
+//        module only registers its timestamp if it matches its own
+//        address with the next specified internet address.
+//
+// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UT.
+//
+// The Timestamp is a right-justified, 32-bit timestamp in
+// milliseconds since midnight UT.  If the time is not available in
+// milliseconds or cannot be provided with respect to midnight UT
+// then any time may be inserted as a timestamp provided the high
+// order bit of the timestamp field is set to one to indicate the
+// use of a non-standard value.

+// IPv4OptTSFlags defines the values expected in the Timestamp
+// option Flags field.
+type IPv4OptTSFlags uint8
+
+//
+// Timestamp option-specific constants.
+const (
+	// IPv4OptionTimestampHdrLength is the length of the timestamp option header.
+	IPv4OptionTimestampHdrLength = 4
+
+	// IPv4OptionTimestampSize is the size of an IP timestamp.
+	IPv4OptionTimestampSize = 4
+
+	// IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address.
+	IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize
+
+	// IPv4OptionTimestampMaxSize is limited by space for options.
+	IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize
+
+	// IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp
+	// is present.
+	IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0
+
+	// IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and
+	// IP are present.
+	IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1
+
+	// IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that
+	// predefined IP is present.
+	IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3
+)
+
+// ipv4TimestampTime provides the current time as specified in RFC 791.
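+//
+// For example (illustrative only), a clock reading corresponding to
+// 01:00:00.500 UT is 3,600,500 ms after midnight, so the value returned would
+// be 3600500; the result wraps every 86,400,000 ms (24 * 3600 * 1000), i.e.
+// once per day.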
+func ipv4TimestampTime(clock tcpip.Clock) uint32 {
+	const millisecondsPerDay = 24 * 3600 * 1000
+	const nanoPerMilli = 1000000
+	return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay)
+}
+
+// IP Timestamp option fields.
+const (
+	// IPv4OptTSPointerOffset is the offset of the Timestamp pointer field.
+	IPv4OptTSPointerOffset = 2
+
+	// IPv4OptTSOFLWAndFLGOffset is the offset of the combined Flag and Overflow
+	// fields (each being 4 bits).
+	IPv4OptTSOFLWAndFLGOffset = 3
+	// These constants define the sub-byte fields of the Flag and Overflow field.
+	ipv4OptionTimestampOverflowshift     = 4
+	ipv4OptionTimestampFlagsMask    byte = 0x0f
+)
+
+var _ IPv4Option = (*IPv4OptionTimestamp)(nil)
+
+// IPv4OptionTimestamp is a Timestamp option from RFC 791.
+type IPv4OptionTimestamp []byte
+
+// Type implements IPv4Option.Type().
+func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType }
+
+// Size implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) }
+
+// Contents implements IPv4Option.
+func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) }
+
+// Pointer returns the pointer field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Pointer() uint8 {
+	return (*ts)[IPv4OptTSPointerOffset]
+}
+
+// Flags returns the flags field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags {
+	return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask)
+}
+
+// Overflow returns the Overflow field in the IP Timestamp option.
+func (ts *IPv4OptionTimestamp) Overflow() uint8 {
+	return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift
+}
+
+// IncOverflow increments the Overflow field in the IP Timestamp option. It
+// returns the incremented value. If the return value is 0 then the field
+// overflowed.
+func (ts *IPv4OptionTimestamp) IncOverflow() uint8 {
+	(*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift
+	return ts.Overflow()
+}
+
+// UpdateTimestamp updates the fields of the next free timestamp slot.
+func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) {
+	slot := (*ts)[ts.Pointer()-1:]
+
+	switch ts.Flags() {
+	case IPv4OptionTimestampOnlyFlag:
+		binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock))
+		(*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize
+	case IPv4OptionTimestampWithIPFlag:
+		if n := copy(slot, addr); n != IPv4AddressSize {
+			panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize))
+		}
+		binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+		(*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+	case IPv4OptionTimestampWithPredefinedIPFlag:
+		if tcpip.Address(slot[:IPv4AddressSize]) == addr {
+			binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock))
+			(*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize
+		}
+	}
+}
+
+// RecordRoute option-specific constants.
+//
+// from RFC 791 page 20:
+//   Record Route
+//
+//         +--------+--------+--------+---------//--------+
+//         |00000111| length | pointer|     route data    |
+//         +--------+--------+--------+---------//--------+
+//           Type=7
+//
+//         The record route option provides a means to record the route of
+//         an internet datagram.
+//
+//         The option begins with the option type code.
The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index c5d8a3456..4e7e5f76a 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -101,8 +101,10 @@ const ( // The address is ff02::2. IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, - // section 5. + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. IPv6MinimumMTU = 1280 // IPv6Loopback is the IPv6 Loopback address. @@ -373,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback +// address. +func IsV6LoopbackAddress(addr tcpip.Address) bool { + return addr == IPv6Loopback +} + // IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 // link-local multicast address. 
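+// For example, ff02::1 (the all-nodes address) is a link-local multicast
+// address, while ff05::2 (site-local scope) is multicast but not link-local.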
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 560477926..b3e8c4b92 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // // We don't clone the original packet buffer so that the new packet buffer // does not have any of its headers set. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) + // + // We trim the link headers from the cloned buffer as the sniffer doesn't + // handle link headers. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + vv.TrimFront(len(pkt.LinkHeader().View())) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) switch protocol { case header.IPv4ProtocolNumber: if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 0243424f6..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "tun_endpoint_refs.go", package = "tun", prefix = "tunEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tunEndpoint", }, @@ -28,6 +28,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..cda6328a2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) - // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. 
endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.EnableLeakCheck() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index b40dde96b..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -30,5 +30,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..33a4a0720 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,7 @@ package arp import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -121,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } @@ -144,34 +145,43 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. 
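+	//
+	// For reference, an IPv4-over-Ethernet ARP body (header.ARPSize bytes) is
+	// laid out as: hardware type (2), protocol type (2), hardware address
+	// length (1), protocol address length (1), opcode (2), sender MAC (6),
+	// sender IPv4 (4), target MAC (6), target IPv4 (4). The copies below fill
+	// in the sender and target fields of the reply.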
+ _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -199,6 +209,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -227,26 +238,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - NetProto: ProtocolNumber, - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) - - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. 
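+	//
+	// In the request, the sender fields identify this NIC and the chosen
+	// local protocol address, the target protocol address is the address
+	// being resolved, and the target hardware address is left as all zeroes
+	// since the freshly pushed ARP header is zero filled.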
+ _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -286,6 +315,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // Note, to make sure that the ARP endpoint receives ARP packets, the "arp" // address must be added to every NIC that should respond to ARP requests. See // ProtocolAddress for more details. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 626af975a..087ee9c66 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,11 @@ func (t eventType) String() string { type eventInfo struct { eventType eventType nicID tcpip.NICID - addr tcpip.Address - linkAddr tcpip.LinkAddress - state stack.NeighborState + entry stack.NeighborEntry } func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) } // arpDispatcher implements NUDDispatcher to validate the dispatching of @@ -96,35 +95,29 @@ type arpDispatcher struct { var _ stack.NUDDispatcher = (*arpDispatcher)(nil) -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryChanged, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryRemoved, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } @@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + if diff := cmp.Diff(got, 
want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { return fmt.Errorf("got invalid event (-got +want):\n%s", diff) } case <-ctx.Done(): @@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { wantEvent := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: test.senderAddr, - linkAddr: tcpip.LinkAddress(test.senderLinkAddr), - state: stack.Stale, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, } if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { t.Fatal(err) @@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) } - if got, want := neigh.LocalAddr, stackAddr; got != want { - t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) - } if got, want := neigh.State, stack.Stale; got != want { t.Errorf("got neighbor State = %s, want = %s", got, want) } @@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + tests := []struct { name string + nicAddr tcpip.Address + localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", remoteLinkAddr: remoteLinkAddr, - expectLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, }, { - name: "Multicast", + name: "Multicast with no local address 
available", remoteLinkAddr: "", - expectLinkAddr: header.EthernetBroadcastAddress, + expectedErr: tcpip.ErrNetworkUnreachable, }, } for _, test := range tests { - p := arp.NewProtocol(nil) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) - } + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) - } + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. 
+			if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr {
+				t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr)
+			}
+
+			if test.expectedErr != nil {
+				return
+			}
+
+			pkt, ok := linkEP.Read()
+			if !ok {
+				t.Fatal("expected to send a link address request")
+			}
+
+			if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr {
+				t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr)
+			}
+
+			rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader()))
+			if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+				t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr)
+			}
+			if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr {
+				t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr)
+			}
+			if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want {
+				t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want)
+			}
+			if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr {
+				t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr)
+			}
+		})
 	}
 }
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index ed502a473..936601287 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea
 // proto is the protocol number marked in the fragment being processed. It has
 // to be given here outside of the FragmentID struct because IPv6 should not use
 // the protocol to identify a fragment.
+//
+// releaseCB is a callback that will run when the fragment reassembly of a
+// packet is complete or cancelled. releaseCB takes a boolean argument which is
+// true iff the reassembly is cancelled due to timeout. releaseCB should be
+// passed only with the first fragment of a packet. If more than one releaseCB
+// is passed for the same packet, only the first releaseCB will be saved for
+// the packet and the succeeding ones will be dropped by running them
+// immediately with a false argument.
 func (f *Fragmentation) Process(
-	id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (
+	id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) (
 	buffer.VectorisedView, uint8, bool, error) {
 	if first > last {
 		return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs)
@@ -171,6 +179,12 @@ func (f *Fragmentation) Process(
 			f.releaseReassemblersLocked()
 		}
 	}
+	if releaseCB != nil {
+		if !r.setCallback(releaseCB) {
+			// We got a duplicate callback. Release it immediately.
+			releaseCB(false /* timedOut */)
+		}
+	}
 	f.mu.Unlock()
 
 	res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv)
@@ -178,14 +192,14 @@ func (f *Fragmentation) Process(
 		// We probably got an invalid sequence of fragments. Just
 		// discard the reassembler and move on.
f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -195,14 +209,14 @@ func (f *Fragmentation) Process( if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + r.release(timedOut) // releaseCB may run. } // releaseReassemblersLocked releases already-expired reassemblers, then @@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() { break } // If the oldest reassembler has already expired, release it. - f.release(r) + f.release(r, true /* timedOut*/) } } // PacketFragmenter is the book-keeping struct for packet fragmentation. type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - innerMTU int - fragmentCount int - currentFragment int - fragmentOffset int + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int } // MakePacketFragmenter prepares the struct needed for packet fragmentation. // // pkt is the packet to be fragmented. // -// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can // have. // // reserve is the number of bytes that should be reserved for the headers in // each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be // repeated in each fragment. However we do not currently support any header // of that kind yet, so the following computation is valid for both IPv4 and @@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) fragmentableData.Append(pkt.Data) - fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - innerMTU: innerMTU, - fragmentCount: fragmentCount, + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), } } @@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. 
- copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index d3c7d7f92..5dcd10730 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) @@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) { for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) { func TestMemoryLimits(t *testing.T) { f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) { func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send the same packet again. 
- f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) got := f.size want := 1 @@ -377,7 +377,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) { tests := []struct { name string - innerMTU int + fragmentPayloadLen uint32 transportHeaderLen int payloadSize int wantFragments []fragmentInfo }{ { name: "Packet exactly fits in MTU", - innerMTU: 1280, + fragmentPayloadLen: 1280, transportHeaderLen: 0, payloadSize: 1280, wantFragments: []fragmentInfo{ @@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet exactly does not fit in MTU", - innerMTU: 1000, + fragmentPayloadLen: 1000, transportHeaderLen: 0, payloadSize: 1001, wantFragments: []fragmentInfo{ @@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a transport header", - innerMTU: 560, + fragmentPayloadLen: 560, transportHeaderLen: 40, payloadSize: 560, wantFragments: []fragmentInfo{ @@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a huge transport header", - innerMTU: 500, + fragmentPayloadLen: 500, transportHeaderLen: 1300, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) { originalPayload.AppendView(pkt.TransportHeader().View()) originalPayload.Append(pkt.Data) var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { fragPkt, offset, copied, more := pf.BuildNextFragment() wantFragment := test.wantFragments[i] @@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) { if more != wantFragment.more { t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) } - if got := fragPkt.Size(); got > test.innerMTU { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) } if got := fragPkt.AvailableHeaderBytes(); got != reserve { t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) @@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) { }) } } + +func TestReleaseCallback(t *testing.T) { + const ( + proto = 99 + ) + + var result int + var callbackReasonIsTimeout bool + cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + + tests := []struct { + name string + callbacks []func(bool) + timeout bool + wantResult int + wantCallbackReasonIsTimeout bool + }{ + { + name: "callback runs on release", + callbacks: []func(bool){cb1}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + 
name: "first callback is nil", + callbacks: []func(bool){nil, cb2}, + timeout: false, + wantResult: 2, + wantCallbackReasonIsTimeout: false, + }, + { + name: "two callbacks - first one is set", + callbacks: []func(bool){cb1, cb2}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "callback runs on timeout", + callbacks: []func(bool){cb1}, + timeout: true, + wantResult: 1, + wantCallbackReasonIsTimeout: true, + }, + { + name: "no callbacks", + callbacks: []func(bool){nil}, + timeout: false, + wantResult: 0, + wantCallbackReasonIsTimeout: false, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result = 0 + callbackReasonIsTimeout = false + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + + for i, cb := range test.callbacks { + _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) + if err != nil { + t.Errorf("f.Process error = %s", err) + } + } + + r, ok := f.reassemblers[id] + if !ok { + t.Fatalf("Reassemberr not found") + } + f.release(r, test.timeout) + + if result != test.wantResult { + t.Errorf("got result = %d, want = %d", result, test.wantResult) + } + if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { + t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9bb051a30..c0cc0bde0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -41,6 +41,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 + callback func(bool) } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } + +func (r *reassembler) setCallback(c func(bool)) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.callback != nil { + return false + } + r.callback = c + return true +} + +func (r *reassembler) release(timedOut bool) { + r.mu.Lock() + callback := r.callback + r.callback = nil + r.mu.Unlock() + + if callback != nil { + callback(timedOut) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..fa2a70dc8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) { } } } + +func TestSetCallback(t *testing.T) { + result := 0 + reasonTimeout := false + + cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } + + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + if !r.setCallback(cb1) { + t.Errorf("setCallback failed") + } + if r.setCallback(cb2) { + t.Errorf("setCallback should fail if one is already set") + } + r.release(true) + if result != 1 { + t.Errorf("got result = %d, want = 1", result) + } + if !reasonTimeout { + t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f20b94d97..8873bd91f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol 
tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { + netHdr := pkt.Network() + t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -304,6 +305,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -604,7 +609,8 @@ func TestIPv4Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -690,6 +696,10 @@ func TestIPv4ReceiveControl(t *testing.T) { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. nic.testObject.protocol = 10 @@ -699,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -780,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } @@ -792,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -892,7 +906,8 @@ func TestIPv6Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -1009,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. 
icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -1063,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum tcpip.NetworkProtocolNumber nicAddr tcpip.Address remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.View + pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) expectedErr *tcpip.Error }{ @@ -1073,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1087,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1115,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1129,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1139,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1148,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1158,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1167,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any 
{ @@ -1195,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) @@ -1213,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) } - return hdr.View() + return hdr.View().ToVectorisedView() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with options and data across views", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + vv := buffer.View(ip).ToVectorisedView() + vv.AppendView(ipv4Options) + vv.AppendView(data) + return vv }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1243,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1256,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1283,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1299,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1326,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: 
ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1334,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1361,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1369,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1413,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { defer r.Release() if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + Data: test.pktGen(t, subTest.srcAddr), })); err != test.expectedErr { t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7fc12e229..6252614ec 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -29,6 +29,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..9b5e37fee 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -23,10 +24,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { @@ -41,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. 
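The handleControl rework above digs the offending datagram's IPv4 header out of an ICMP error's payload before deciding which transport endpoint to notify. As a rough, standalone illustration of that step (plain Go and the standard library only; the struct and function names below are mine, not gvisor's), the embedded header can be pulled apart like this:

package main

import (
    "encoding/binary"
    "errors"
    "fmt"
    "net"
)

// embeddedIPv4 summarizes the IPv4 header that an ICMPv4 error message
// carries at the start of its payload (RFC 792: "Internet Header + 64 bits
// of Original Data Datagram").
type embeddedIPv4 struct {
    Protocol  uint8
    Src, Dst  net.IP
    Transport []byte // start of the original transport header
}

// parseEmbeddedIPv4 sketches the work handleControl does: check that at
// least a minimal IPv4 header is present, honour the IHL field, and return
// the pieces needed to notify the right transport endpoint.
func parseEmbeddedIPv4(payload []byte) (embeddedIPv4, error) {
    const minHdr = 20
    if len(payload) < minHdr {
        return embeddedIPv4{}, errors.New("truncated embedded IPv4 header")
    }
    ihl := int(payload[0]&0x0f) * 4 // header length in bytes
    if ihl < minHdr || ihl > len(payload) {
        return embeddedIPv4{}, errors.New("bad IHL in embedded IPv4 header")
    }
    return embeddedIPv4{
        Protocol:  payload[9],
        Src:       net.IP(payload[12:16]),
        Dst:       net.IP(payload[16:20]),
        Transport: payload[ihl:],
    }, nil
}

func main() {
    hdr := make([]byte, 20, 28)
    hdr[0] = 0x45 // version 4, IHL 5
    hdr[9] = 17   // UDP
    copy(hdr[12:16], net.IPv4(192, 0, 2, 1).To4())
    copy(hdr[16:20], net.IPv4(198, 51, 100, 7).To4())
    hdr = append(hdr, make([]byte, 8)...)        // first 8 bytes of the UDP header
    binary.BigEndian.PutUint16(hdr[20:22], 4321) // original source port

    info, err := parseEmbeddedIPv4(hdr)
    if err != nil {
        panic(err)
    }
    fmt.Printf("proto=%d src=%s dst=%s srcPort=%d\n",
        info.Protocol, info.Src, info.Dst,
        binary.BigEndian.Uint16(info.Transport[:2]))
}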
- src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { return } @@ -57,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { - stats := r.Stats() +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() received := stats.ICMP.V4PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a @@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if len(iph) > header.IPv4MinimumSize { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - headerChecksum := h.Checksum() - h.SetChecksum(0) - calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(headerChecksum) - if calculatedChecksum != headerChecksum { - // It's possible that a raw socket still expects to receive this. 
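The new checksum test above (comparing the checksum over the whole received message against 0xffff) relies on a property of the one's-complement internet checksum: summing a message that includes its own checksum field yields 0xffff when the message is intact. A minimal sketch of that arithmetic, using only the standard library rather than gvisor's checksum helpers:

package main

import (
    "encoding/binary"
    "fmt"
)

// onesComplementSum folds data into a 16-bit one's-complement sum, the
// arithmetic used by the IPv4 and ICMPv4 checksums (RFC 1071).
func onesComplementSum(data []byte) uint16 {
    var sum uint32
    for len(data) >= 2 {
        sum += uint32(binary.BigEndian.Uint16(data[:2]))
        data = data[2:]
    }
    if len(data) == 1 {
        sum += uint32(data[0]) << 8 // pad the odd byte on the right
    }
    for sum > 0xffff {
        sum = (sum >> 16) + (sum & 0xffff) // fold the carries back in
    }
    return uint16(sum)
}

// checksum returns the value to place in a checksum field: the one's
// complement of the sum computed with that field set to zero.
func checksum(msg []byte) uint16 { return ^onesComplementSum(msg) }

// valid reports whether a received message checks out. Because the stored
// checksum is the complement of the rest, summing the message including the
// checksum field yields 0xffff when nothing was corrupted -- the same test
// the reworked handleICMP applies to the incoming data.
func valid(msg []byte) bool { return onesComplementSum(msg) == 0xffff }

func main() {
    echo := []byte{8, 0, 0, 0, 0x12, 0x34, 0, 1, 'h', 'i'} // type 8 (echo), code 0
    binary.BigEndian.PutUint16(echo[2:4], checksum(echo))
    fmt.Println("valid:", valid(echo)) // true
    echo[len(echo)-1] ^= 0xff          // corrupt the payload
    fmt.Println("valid:", valid(echo)) // false
}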
- e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4PacketsSent + if !e.protocol.stack.AllowICMPMessage() { + sent.RateLimited.Increment() return } @@ -98,19 +144,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() - replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast + + // It's possible that a raw socket expects to receive this. + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + pkt = nil - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). - localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr := ipHdr.DestinationAddress() + if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -139,7 +193,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // The fields we need to alter. // // We need to produce the entire packet in the data segment in order to - // use WriteHeaderIncludedPacket(). + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. replyIPHdr.SetSourceAddress(r.LocalAddress) replyIPHdr.SetDestinationAddress(r.RemoteAddress) replyIPHdr.SetTTL(r.DefaultTTL()) @@ -157,8 +212,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { }) replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // The checksum will be calculated so we don't need to do it here. 
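The replacement code above takes the Echo Reply's source from the request's destination, falling back to an empty local address when that destination was a broadcast or multicast address so the route lookup can choose a usable unicast address instead. A one-function sketch of that rule, using net.IP rather than tcpip.Address (the helper name is illustrative):

package main

import (
    "fmt"
    "net"
)

// replySource picks the source address for an ICMP Echo Reply, following
// RFC 1122 section 3.2.1.3: the reply MUST come from one of the host's own
// unicast addresses, never a broadcast or multicast one. When the request
// was sent to such an address, return nil and let the route lookup pick a
// suitable local address -- the same idea as passing an empty localAddr to
// FindRoute in the hunk above.
func replySource(requestDst net.IP, dstWasSubnetBroadcast bool) net.IP {
    if dstWasSubnetBroadcast || requestDst.Equal(net.IPv4bcast) || requestDst.IsMulticast() {
        return nil // defer to routing
    }
    return requestDst
}

func main() {
    fmt.Println(replySource(net.IPv4(192, 0, 2, 1), false))  // 192.0.2.1
    fmt.Println(replySource(net.IPv4(224, 0, 0, 1), false))  // <nil>
    fmt.Println(replySource(net.IPv4(192, 0, 2, 255), true)) // <nil>
}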
- sent := stats.ICMP.V4PacketsSent if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return @@ -168,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() @@ -182,8 +235,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -234,12 +290,31 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv4(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -263,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in // response to a non-initial fragment, but it currently can not happen. - - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { return nil } @@ -272,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. 
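returnError above declines to answer datagrams that arrived as broadcast or multicast, or that carry an unusable source address, before it even attempts a route lookup. A minimal sketch of those RFC 1812 section 4.3.2.7 address checks, with net.IP standing in for gvisor's types (the rate limiting the real code also applies is omitted):

package main

import (
    "fmt"
    "net"
)

// mayReturnICMPError applies the address-based suppression rules referenced
// above: never answer a datagram that was destined to a broadcast or
// multicast address, or whose source cannot meaningfully be replied to.
func mayReturnICMPError(src, dst net.IP, dstWasLinkOrSubnetBroadcast bool) bool {
    switch {
    case dstWasLinkOrSubnetBroadcast, dst.Equal(net.IPv4bcast), dst.IsMulticast():
        return false // the error would answer a broadcast: amplification risk
    case src.IsUnspecified(), src.Equal(net.IPv4bcast), src.IsMulticast():
        return false // no unicast sender to notify
    default:
        return true
    }
}

func main() {
    fmt.Println(mayReturnICMPError(net.IPv4(192, 0, 2, 9), net.IPv4(198, 51, 100, 1), false)) // true
    fmt.Println(mayReturnICMPError(net.IPv4(0, 0, 0, 0), net.IPv4(198, 51, 100, 1), false))   // false
    fmt.Println(mayReturnICMPError(net.IPv4(192, 0, 2, 9), net.IPv4(224, 0, 0, 251), false))  // false
}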
- route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil sent := p.stack.Stats().ICMP.V4PacketsSent if !p.stack.AllowICMPMessage() { @@ -287,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. - if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { // TODO(gvisor.dev/issue/3810): // Unfortunately the current stack pretty much always has ICMPv4 headers // in the Data section of the packet but there is no guarantee that is the @@ -348,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -360,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // view with the entire incoming IP packet reassembled and truncated as // required. This is now the payload of the new ICMP packet and no longer // considered a packet in its own right. - newHeader := append(buffer.View(nil), networkHeader...) + newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) 
payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) @@ -374,17 +444,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - switch reason.(type) { + var counter *tcpip.StatCounter + switch reason := reason.(type) { case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - counter := sent.DstUnreachable if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e7c58ae0a..cfd0c505a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -16,7 +16,9 @@ package ipv4 import ( + "errors" "fmt" + "math" "sync/atomic" "time" @@ -31,6 +33,8 @@ import ( ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. // As per RFC 791 section 3.2: // The current recommendation for the initial timer setting is 15 seconds. // This may be changed as experience with this protocol accumulates. @@ -38,7 +42,7 @@ const ( // Considering that it is an old recommendation, we use the same reassembly // timeout that linux defines, which is 30 seconds: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 - reassembleTimeout = 30 * time.Second + ReassembleTimeout = 30 * time.Second // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -211,18 +219,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) -} - // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. 
-func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) var n int for { @@ -247,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() return nil @@ -265,23 +269,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. 
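handleFragments above rounds the usable MTU down with a bitwise AND NOT of 7 because the Fragment Offset field counts 8-octet blocks, so every fragment except the last must carry a multiple of eight bytes of payload. A small standalone sketch of the resulting fragment boundaries (the function is illustrative, not gvisor's PacketFragmenter):

package main

import "fmt"

// fragmentSpans splits a payload of length payloadLen across fragments whose
// data portion may be at most networkMTU bytes (link MTU minus the IP header
// this endpoint will emit). Every fragment except the last must carry a
// multiple of 8 bytes, hence the round-down to an 8-byte boundary.
func fragmentSpans(payloadLen, networkMTU int) [][2]int {
    per := networkMTU &^ 7 // round down to an 8-byte boundary
    if per <= 0 {
        return nil // cannot fragment at all
    }
    var spans [][2]int
    for off := 0; off < payloadLen; off += per {
        end := off + per
        if end > payloadLen {
            end = payloadLen // the final fragment may be unaligned
        }
        spans = append(spans, [2]int{off, end})
    }
    return spans
}

func main() {
    // 3000 bytes of payload over a path that leaves 1480 bytes after the
    // 20-byte IPv4 header (a typical 1500-byte Ethernet MTU).
    fmt.Println(fragmentSpans(3000, 1480)) // [[0 1480] [1480 2960] [2960 3000]]
}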
+ pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, pkt) - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -292,6 +313,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -311,17 +333,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - if e.packetMustBeFragmented(pkt, gso) { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt return nil }); err != nil { - panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) } // Remove the packet that was just fragmented and process the rest. 
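Both write paths above now derive a per-packet network MTU from the link MTU and the header the packet actually carries. The sketch below mirrors the shape of that calculation; the minimum-link-MTU constant is an assumption on my part, since the exact bound the stack enforces is defined outside this hunk:

package main

import (
    "errors"
    "fmt"
)

const (
    maxIPv4TotalSize  = 0xffff // 16-bit Total Length field
    maxIPv4HeaderSize = 60     // RFC 791: IHL caps the header at 60 octets
    // Assumed lower bound for a usable link MTU; the real stack rejects
    // links below its own IPv4 minimum-MTU constant, whose value is not
    // shown in this hunk.
    assumedMinLinkMTU = 68
)

// networkMTU mirrors the shape of calculateNetworkMTU above: clamp the link
// MTU to what a single IPv4 datagram can describe, reject absurd inputs, and
// subtract the header this packet will carry.
func networkMTU(linkMTU, headerSize uint32) (uint32, error) {
    if linkMTU < assumedMinLinkMTU {
        return 0, errors.New("link MTU below the IPv4 minimum")
    }
    if headerSize > maxIPv4HeaderSize {
        return 0, errors.New("malformed IPv4 header size")
    }
    if linkMTU > maxIPv4TotalSize {
        linkMTU = maxIPv4TotalSize
    }
    return linkMTU - headerSize, nil
}

func main() {
    fmt.Println(networkMTU(1500, 20))  // 1480 <nil>
    fmt.Println(networkMTU(70000, 20)) // 65515 <nil> (clamped to 65535 first)
    fmt.Println(networkMTU(40, 20))    // 0 link MTU below the IPv4 minimum
}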
pkts.Remove(originalPkt) @@ -355,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -385,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return tcpip.ErrMalformedHeader } + + hdrLen := header.IPv4(h).HeaderLength() + if hdrLen < header.IPv4MinimumSize { + return tcpip.ErrMalformedHeader + } + + h, ok = pkt.Data.PullUp(int(hdrLen)) + if !ok { + return tcpip.ErrMalformedHeader + } ip := header.IPv4(h) // Always set the total length. @@ -429,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -462,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -470,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -480,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -488,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. 
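HandlePacket above rejects fragments whose offset plus payload could never fit in a reassembled IPv4 datagram, as well as fragments that carry no payload at all. A compact sketch of that bounds check on the raw flags/fragment-offset field (the constant names are mine):

package main

import "fmt"

const (
    ipv4MaxPayload = 0xffff - 20 // maximum Total Length minus a minimal header
    moreFragments  = 0x2000      // MF bit inside the flags/fragment-offset field
    offsetMask     = 0x1fff      // low 13 bits hold the offset in 8-octet units
)

// fragmentOK mirrors the sanity check in the hunk above: given the raw
// 16-bit flags/fragment-offset field and the fragment's payload size, the
// reassembled datagram must still fit in an IPv4 packet, and a fragment with
// no payload is never valid.
func fragmentOK(flagsAndOffset uint16, payloadLen int) bool {
    if payloadLen == 0 {
        return false
    }
    start := int(flagsAndOffset&offsetMask) * 8 // the offset field counts 8-byte blocks
    return start+payloadLen <= ipv4MaxPayload
}

func main() {
    fmt.Println(fragmentOK(moreFragments|185, 1480)) // true: starts at byte 1480
    fmt.Println(fragmentOK(0x1fff, 1480))            // false: would overflow 65515 bytes
    fmt.Println(fragmentOK(moreFragments, 0))        // false: fragment with no payload
}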
@@ -502,10 +545,30 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } + + // Set up a callback in case we need to send a Time Exceeded Message, as per + // RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards + // the datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } + } + } + var ready bool var err error proto := h.Protocol() @@ -523,29 +586,56 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.More(), proto, pkt.Data, + releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if !ready { return } + + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) } + stats.IP.PacketsDelivered.Increment() - r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport // headers, the setting of the transport number here should be // unnecessary and removed. pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt) + e.handleICMP(pkt) return } + if len(h.Options()) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + } - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination @@ -553,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. 
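The release callback added above sends an ICMP Time Exceeded only when reassembly times out and only if the zero-offset fragment was seen, as RFC 792 requires. The toy type below sketches that closure-plus-timer pattern in isolation; it is not gvisor's fragmentation package, just the idea:

package main

import (
    "fmt"
    "sync"
    "time"
)

// reassembly is a toy stand-in for one entry in a fragment reassembly table.
// It models only the part the hunk above adds: an optional callback that
// fires when the entry is released, telling the caller whether release was
// due to a timeout (so it can send Time Exceeded) or to normal completion.
type reassembly struct {
    mu        sync.Mutex
    done      bool
    timer     *time.Timer
    releaseCB func(timedOut bool)
}

func newReassembly(timeout time.Duration, onRelease func(timedOut bool)) *reassembly {
    r := &reassembly{releaseCB: onRelease}
    r.timer = time.AfterFunc(timeout, func() { r.release(true) })
    return r
}

// complete is what a successful reassembly would call.
func (r *reassembly) complete() { r.timer.Stop(); r.release(false) }

func (r *reassembly) release(timedOut bool) {
    r.mu.Lock()
    defer r.mu.Unlock()
    if r.done {
        return
    }
    r.done = true
    if r.releaseCB != nil {
        r.releaseCB(timedOut)
    }
}

func main() {
    // Per RFC 792, only register the callback once fragment zero has been
    // seen; without it there is nothing meaningful to quote back.
    haveFragmentZero := true
    var cb func(bool)
    if haveFragmentZero {
        cb = func(timedOut bool) {
            if timedOut {
                fmt.Println("would send ICMP Time Exceeded (reassembly timeout)")
            }
        }
    }
    r := newReassembly(50*time.Millisecond, cb)
    _ = r // never completed: the timer fires and the callback reports a timeout
    time.Sleep(100 * time.Millisecond)
}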
- _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -602,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo loopback := e.nic.IsLoopback() addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { - subnet := addressEndpoint.AddressWithPrefix().Subnet() + subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) @@ -778,26 +868,32 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return mtu - header.IPv4MinimumSize -} -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader } - mtu -= uint32(pkt.NetworkHeader().View().Size()) - // Round the MTU down to align to 8 bytes. 
- mtu &^= 7 - return mtu + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize + } + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 @@ -836,7 +932,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), } } @@ -846,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head originalIPHeaderLength := len(originalIPHeader) nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) @@ -862,3 +959,324 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head return fragPkt, more } + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. 
+func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. 
+ // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if nextSlot >= optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. + if nextSlot+header.IPv4AddressSize > optlen { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. 
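handleRecordRoute above works on the raw Record Route option: a three-byte header of type, length and 1-based pointer, followed by four-byte address slots; a full option is forwarded untouched, while an option with some room but not a whole free slot draws a Parameter Problem. A standalone sketch over a plain byte slice, reusing the values from the test table later in this patch (the function itself is illustrative, not gvisor's header package):

package main

import (
    "errors"
    "fmt"
)

// recordRoute updates a raw Record Route option in place, following RFC 791
// page 21: layout is {type, length, pointer, addr, addr, ...} with a 1-based
// pointer whose smallest legal value is 4. A full option is passed through
// untouched; an option with some room but not enough for a whole 4-byte
// address is an error (a Parameter Problem in the code above).
func recordRoute(opt []byte, local [4]byte) error {
    const hdrLen, addrLen = 3, 4
    if len(opt) < hdrLen+addrLen || int(opt[1]) != len(opt) {
        return errors.New("record route option too short")
    }
    length := int(opt[1])
    next := int(opt[2]) - 1 // make the 1-based pointer 0-based
    if next >= length {
        return nil // already full: forward without recording (not an error)
    }
    if next+addrLen > length {
        return errors.New("room left, but not enough for an address")
    }
    copy(opt[next:next+addrLen], local[:])
    opt[2] += addrLen // advance the pointer past the slot we just filled
    return nil
}

func main() {
    // One empty 4-byte slot: type 7, length 7, pointer 4.
    opt := []byte{7, 7, 4, 0, 0, 0, 0}
    fmt.Println(recordRoute(opt, [4]byte{192, 168, 1, 58}), opt) // <nil> [7 7 8 192 168 1 58]

    // Full option (pointer 8 exceeds length 7): silently left alone.
    fmt.Println(recordRoute(opt, [4]byte{10, 0, 0, 1}), opt) // <nil> [7 7 8 192 168 1 58]
}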
+ return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + stats := e.protocol.stack.Stats() + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = dstAddr + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + stats.IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := e.protocol.stack.Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + stats.IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + stats.IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily 
checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index fee11bb38..c7f434591 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -21,11 +21,13 @@ import ( "math" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -39,7 +41,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -const extraHeaderReserve = 50 +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ @@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -103,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) { // checks the response. func TestIPv4Sanity(t *testing.T) { const ( - defaultMTU = header.IPv6MinimumMTU ttl = 255 nicID = 1 randomSequence = 123 @@ -118,27 +121,29 @@ func TestIPv4Sanity(t *testing.T) { ) tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - shouldFail bool - expectICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - options []byte + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options []byte + replyOptions []byte // if succeeds, reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 }{ { - name: "valid", - maxTotalLength: defaultMTU, + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, { name: "bad header checksum", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, badHeaderChecksum: true, @@ -157,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) { // received with TTL less than 2. 
{ name: "zero TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, - shouldFail: false, }, { name: "one TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, - shouldFail: false, }, { name: "End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{0, 0, 0, 0}, + replyOptions: []byte{0, 0, 0, 0}, }, { name: "NOP options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 1, 1}, + replyOptions: []byte{1, 1, 1, 1}, }, { name: "NOP and End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 0, 0}, + replyOptions: []byte{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (0)", @@ -205,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip - 1)", @@ -213,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip + icmp - 1)", @@ -221,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad protocol", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, - expectICMP: true, + expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: []byte{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. 
+ replyOptions: []byte{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: []byte{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: []byte{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. 
+ { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. 
+ 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: []byte{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: []byte{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") + e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -250,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -312,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) { reply, ok := e.Read() if !ok { if test.shouldFail { - if test.expectICMP { - t.Fatal("expected ICMP error response missing") + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) @@ -328,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } - // Make sure it's all in one buffer. - vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) - replyIPHeader := header.IPv4(vv.ToView()) + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - // At this stage we only know it's an IP header so verify that much. + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), ) - // All expected responses are ICMP packets. - if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { - t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + // Don't proceed any further if the checker found problems. 
+ if t.Failed() { + t.FailNow() } - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - // Sanity check the response. + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { - case header.ICMPv4DstUnreachable: + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Checksum(), + checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) - if !test.shouldFail || !test.expectICMP { - t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) return case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. 
+ sizeChange := len(test.replyOptions) - len(test.options) + checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength), - checker.IPv4Options(test.options), - checker.IPFullLength(uint16(requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( + checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), - checker.ICMPv4Checksum(), ), ) - if test.shouldFail { - t.Fatalf("unexpected Echo Reply packet\n") - } default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } @@ -462,7 +840,7 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", + description: "No fragmentation", mtu: 1280, gso: nil, transportHeaderLength: 0, @@ -483,6 +861,30 @@ var fragmentationTests = []struct { }, }, { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { description: "No fragmentation with big header", mtu: 2000, gso: nil, @@ -647,43 +1049,50 @@ func TestFragmentationWritePackets(t *testing.T) { } } -// TestFragmentationErrors checks that errors are returned from write packet +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. 
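
The two new minimum-MTU cases in fragmentationTests above follow directly from how the per-fragment payload is derived: the usable space is the MTU minus the 20-byte IPv4 header, rounded down to a multiple of 8, since every fragment except the last must carry a multiple of 8 bytes. A small helper, illustrative only and assuming header.IPv4MinimumMTU is 68 as the expected values imply:

func expectedFragmentPayloads(mtu, ipHeaderLen, payload int) []int {
	per := (mtu - ipHeaderLen) &^ 7 // round down to a multiple of 8
	var sizes []int
	for payload > 0 {
		n := per
		if payload < n {
			n = payload
		}
		sizes = append(sizes, n)
		payload -= n
	}
	return sizes
}

expectedFragmentPayloads(68, 20, 100) returns [48 48 4], and expectedFragmentPayloads(69, 20, 100) returns the same because 49 rounds down to 48, matching the two wantFragments lists above. TestFragmentationErrors, next, covers the error paths of the same splitting logic.
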
func TestFragmentationErrors(t *testing.T) { const ttl = 42 - expectedError := tcpip.ErrAborted - fragTests := []struct { + tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int - fragmentCount int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ { description: "No frag", mtu: 2000, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 1, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 3, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on second frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 1, - fragmentCount: 3, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag MTU smaller than header", @@ -691,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, - fragmentCount: 4, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) - r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != expectedError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } @@ -744,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) { autoChecksum bool // if true, the Checksum field will be overwritten. } - // These packets have both IHL and TotalLength set to 0. 
tests := []struct { name string fragments []fragmentData @@ -984,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -1027,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) { } } +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + 
Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + Clock: clock, + }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) + } + + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { @@ -1506,13 +2178,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1527,17 +2196,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. 
ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1577,7 +2243,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -1783,7 +2449,7 @@ func TestPacketQueing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a30437f02..0ac24a6fb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ead6bedcb..8502b848c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { }) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := r.Stats().ICMP +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { + stats := e.protocol.stack.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their @@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) + srcAddr := iph.SourceAddress() + dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. // // This copy is used as extra payload during the checksum calculation. 
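
For reference, header.ICMPv6Checksum used just below computes the standard Internet checksum over the IPv6 pseudo-header (source address, destination address, upper-layer length, and a Next Header value of 58) followed by the ICMPv6 message with its checksum field treated as zero (RFC 4443 section 2.3). A standalone sketch of that calculation, not the gVisor implementation:

func icmpv6ChecksumSketch(src, dst [16]byte, msg []byte) uint16 {
	var sum uint32
	add := func(b []byte) {
		for i := 0; i+1 < len(b); i += 2 {
			sum += uint32(b[i])<<8 | uint32(b[i+1])
		}
		if len(b)%2 == 1 {
			sum += uint32(b[len(b)-1]) << 8 // pad the trailing odd byte
		}
	}
	add(src[:])
	add(dst[:])
	l := uint32(len(msg))
	add([]byte{byte(l >> 24), byte(l >> 16), byte(l >> 8), byte(l)})
	add([]byte{0, 0, 0, 58}) // three zero bytes, then Next Header = ICMPv6
	add(msg)                 // ICMPv6 header and payload, checksum field zeroed
	for sum>>16 != 0 {
		sum = sum&0xffff + sum>>16
	}
	return ^uint16(sum)
}

The clone-and-trim just below exists to supply those trailing payload bytes without disturbing pkt.Data itself.
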
payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { received.Invalid.Increment() return } @@ -170,8 +172,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -221,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // we know we are also performing DAD on it). In this case we let the // stack know so it can handle such a scenario and do nothing further with // the NS. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -248,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -274,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. - unspecifiedSource := r.RemoteAddress == header.IPv6Any + unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { - if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { received.Invalid.Increment() return } @@ -284,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -295,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // ... // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. - if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { received.Invalid.Increment() return } @@ -305,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the source of the solicitation is the unspecified address, the node // MUST [...] and multicast the advertisement to the all-nodes address. 
// - remoteAddr := r.RemoteAddress + remoteAddr := srcAddr if unspecifiedSource { remoteAddr = header.IPv6AllNodesMulticastAddress } @@ -462,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. - localAddr := r.LocalAddress - if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr := dstAddr + if header.IsV6MulticastAddress(dstAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -483,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -495,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -516,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - stack := r.Stack() + stack := e.protocol.stack // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { @@ -547,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { received.Invalid.Increment() return } @@ -555,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -572,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - routerAddr := iph.SourceAddress() + routerAddr := srcAddr // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { @@ -605,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. 
if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -648,52 +657,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := stack.Route{ - LocalAddress: localAddr, - RemoteAddress: addr, - LocalLinkAddress: linkEP.LinkAddress(), - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - // If a remote address is not already known, then send a multicast - // solicitation since multicast addresses have a static mapping to link - // addresses. - if len(r.RemoteLinkAddress) == 0 { - r.RemoteAddress = header.SolicitedNodeAddr(addr) - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err } + defer r.Release() + r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(packet.NDPPayload()) - ns.SetTargetAddress(addr) + ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) + stat := p.stack.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. 
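
When the caller does not already know the remote link address, LinkAddressRequest above multicasts the solicitation: the destination becomes the target's solicited-node address and the link-layer destination is the Ethernet multicast MAC derived from it. A minimal standalone sketch of that mapping (RFC 4291 section 2.7.1 and RFC 2464 section 7), using the standard library instead of the pkg/tcpip/header helpers:

package main

import (
	"fmt"
	"net"
)

// solicitedNodeAddr returns ff02::1:ffXX:XXXX, keeping the low 24 bits of ip.
func solicitedNodeAddr(ip net.IP) net.IP {
	ip = ip.To16()
	sn := net.ParseIP("ff02::1:ff00:0")
	copy(sn[13:], ip[13:])
	return sn
}

// multicastMAC maps an IPv6 multicast address to 33:33 plus its low 32 bits.
func multicastMAC(ip net.IP) net.HardwareAddr {
	ip = ip.To16()
	return net.HardwareAddr{0x33, 0x33, ip[12], ip[13], ip[14], ip[15]}
}

func main() {
	target := net.ParseIP("fe80::1")
	sn := solicitedNodeAddr(target)
	fmt.Println(sn, multicastMAC(sn)) // ff02::1:ff00:1 33:33:ff:00:00:01
}

This is the same mapping that header.SolicitedNodeAddr and header.EthernetAddressFromMulticastIPv6Address perform for the NDP Neighbor Solicitation built here.
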
@@ -747,9 +750,20 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv6(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -776,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { return nil } @@ -784,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. 
- r = nil stats := p.stack.Stats().ICMP sent := stats.V6PacketsSent @@ -839,7 +850,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -860,6 +873,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8dc33c560..76013daa1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -51,6 +51,7 @@ const ( var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -86,7 +87,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { +func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { return stack.TransportPacketHandled } @@ -108,31 +109,27 @@ type stubNUDHandler struct { var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { s.probeCount++ } -func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { s.confirmationCount++ } -func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { } var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint - - linkAddr tcpip.LinkAddress -} + stack.LinkEndpoint -func (i *testInterface) LinkAddress() tcpip.LinkAddress { - return i.linkAddr + nicID tcpip.NICID } func (*testInterface) ID() tcpip.NICID { - return 0 + return nicID } func (*testInterface) IsLoopback() bool { @@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool { return true } +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -277,7 +282,8 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + 
r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -419,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -1235,26 +1242,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + snaddr := header.SolicitedNodeAddr(lladdr0) mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectedLinkAddr tcpip.LinkAddress - expectedAddr tcpip.Address + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectedLinkAddr: linkAddr1, - expectedAddr: lladdr0, + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, }, { - name: "Multicast", - remoteLinkAddr: "", - expectedLinkAddr: mcaddr, - expectedAddr: snaddr, + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, }, } @@ -1269,26 +1322,43 @@ func TestLinkAddressRequest(t *testing.T) { } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. 
+ if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return } pkt, ok := linkEP.Read() if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } - if pkt.Route.RemoteAddress != test.expectedAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) } if pkt.Route.LocalAddress != lladdr1 { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) } checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedAddr), + checker.DstAddr(test.expectedRemoteAddr), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(lladdr0), @@ -1698,7 +1768,7 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1728,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) { SrcAddr: r.RemoteAddress, DstAddr: r.LocalAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. if nudHandler.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9670696c7..0526190cc 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -41,12 +41,12 @@ const ( // // Linux also uses 60 seconds for reassembly timeout: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 - reassembleTimeout = 60 * time.Second + ReassembleTimeout = 60 * time.Second // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff @@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return err } - prefix := addressEndpoint.AddressWithPrefix().Subnet() + prefix := addressEndpoint.Subnet() switch t := addressEndpoint.ConfigType(); t { case stack.AddressConfigStatic: @@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. 
It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and @@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. The transport -// header protocol number is required to avoid parsing the IPv6 extension -// headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - if fragMTU < pkt.TransportHeader().View().Size() { +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. 
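
For concreteness: with a 1500-byte link MTU and a plain 40-byte IPv6 header, networkMTU is 1460, so fragmentPayloadLen is (1460 - 8) &^ 7 = 1448 bytes of fragmentable data per fragment; that is the space the transport header must fit into, or the packet is rejected with tcpip.ErrMessageTooLong just below.
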
return 0, 1, tcpip.ErrMessageTooLong } - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) - networkHeader := header.IPv6(pkt.NetworkHeader().View()) var n int for { @@ -448,28 +465,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ - // The inbound path expects an unparsed packet. - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) - - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -499,13 +528,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) - if e.packetMustBeFragmented(pb, gso) { + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. 
pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -546,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -569,7 +607,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) @@ -607,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() // As per RFC 4291 section 2.7: // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. - if header.IsV6MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if header.IsV6MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -641,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -651,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -663,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. 
if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -675,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -689,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -702,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -727,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -747,6 +790,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. @@ -762,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if done { @@ -790,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } } @@ -799,30 +844,70 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. 
- r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + return + } + + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) return } // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the fragment if the size of the reassembled payload would exceed - // the maximum payload size. + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } + // Set up a callback in case we need to send a Time Exceeded Message as + // per RFC 2460 Section 4.5. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } + } + } + // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. 
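						// (IPv4, by contrast, keys reassembly on the source, destination,
						// protocol and identification fields together, per RFC 791, which is
						// presumably why fragmentation.FragmentID carries a Protocol field;
						// the IPv6 literal below simply leaves it zero.)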
fragmentation.FragmentID{ - Source: h.SourceAddress(), - Destination: h.DestinationAddress(), + Source: srcAddr, + Destination: dstAddr, ID: extHdr.ID(), }, start, @@ -830,10 +915,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.More(), uint8(rawPayload.Identifier), rawPayload.Buf, + releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } pkt.Data = data @@ -852,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -866,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -879,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -902,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt, hasFragmentHeader) + e.handleICMP(pkt, hasFragmentHeader) } else { - r.Stats().IP.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + stats.IP.PacketsDelivered.Increment() + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -916,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -937,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. 
- _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -947,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) - r.Stats().UnknownProtocolRcvdPackets.Increment() + stats.UnknownProtocolRcvdPackets.Increment() return } } @@ -1427,14 +1513,31 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize } - return maxPayloadSize + return networkMTU, nil } // Options holds options to configure a new protocol. @@ -1488,7 +1591,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), ids: ids, hashIV: hashIV, @@ -1509,23 +1612,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are - // supported for outbound packets, their length should not affect the fragment - // MTU because they should only be transmitted once. - mtu -= uint32(pkt.NetworkHeader().View().Size()) - mtu -= header.IPv6FragmentHeaderSize - // Round the MTU down to align to 8 bytes. 
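calculateNetworkMTU, added above in place of calculateMTU, makes the minimum-MTU and header-chain checks explicit. A dependency-free sketch of the clamping it performs (error values are simplified and maxPayload stands in for the package's maxPayloadSize constant):

import "errors"

// networkMTU mirrors calculateNetworkMTU: reject link MTUs below the IPv6
// minimum, reject header chains longer than the minimum MTU (RFC 7112), then
// clamp what is left of the link MTU to the maximum payload size.
func networkMTU(linkMTU, networkHeadersLen, maxPayload uint32) (uint32, error) {
	const ipv6MinimumMTU = 1280
	if linkMTU < ipv6MinimumMTU {
		return 0, errors.New("link MTU is below the IPv6 minimum of 1280 bytes")
	}
	if networkHeadersLen > ipv6MinimumMTU {
		return 0, errors.New("IPv6 header chain is longer than the minimum MTU")
	}
	if mtu := linkMTU - networkHeadersLen; mtu < maxPayload {
		return mtu, nil
	}
	return maxPayload, nil
}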
- mtu &^= 7 - if mtu <= maxPayloadSize { - return mtu - } - return maxPayloadSize -} - func calculateFragmentReserve(pkt *stack.PacketBuffer) int { return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize } @@ -1560,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea originalIPHeadersLength := len(originalIPHeaders) fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 297868f24..1bfcdde25 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" @@ -238,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -271,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -825,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1844,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1912,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - nicID = 1 - fragmentExtHdrLen = 8 + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" ) - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload + type fragmentData struct { + ipv6Fields header.IPv6Fields + 
ipv6FragmentFields header.IPv6FragmentFields + payload []byte } tests := []struct { @@ -1929,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 }{ { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { name: "fragments reassembled into a payload exceeding the max IPv6 payload size", fragments: []fragmentData{ { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, - []buffer.View{ - // Fragment extension header. - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, - ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, - 0, 0, 0, 1}), - // Payload length = 16 - payloadGen(16), - }, - ), + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ }, } @@ -1964,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) { NewProtocol, }, }) - e := channel.New(0, 1500, linkAddr1) + e := channel.New(1, 1500, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + var expectICMPPayload buffer.View for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - // Serialize IPv6 fixed header. 
- ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() - vv.Append(f.data) + vv.AppendView(f.payload) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - })) + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { @@ -1999,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) { if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + 
DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + 
fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -2035,13 +2360,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2056,17 +2378,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. 
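TestFragmentReassemblyTimeout above drives the reassembly timer with a manual clock instead of waiting in real time. The pattern, sketched from outside the ipv6 package so identifiers are qualified:

import (
	"gvisor.dev/gvisor/pkg/tcpip/faketime"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// newIPv6StackWithManualClock builds a stack whose timers advance only when
// the test tells them to, which makes timeout behavior deterministic.
func newIPv6StackWithManualClock() (*stack.Stack, *faketime.ManualClock) {
	clock := faketime.NewManualClock()
	s := stack.New(stack.Options{
		NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocol},
		Clock:            clock,
	})
	return s, clock
}

After injecting an incomplete fragment, clock.Advance(ipv6.ReassembleTimeout) fires the reassembly timer, and any ICMPv6 Time Exceeded response can then be read back from the link endpoint.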
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2230,8 +2549,8 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", - mtu: 1280, + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 0, payloadSize: 1000, @@ -2241,7 +2560,18 @@ var fragmentationTests = []struct { }, { description: "Fragmented", - mtu: 1280, + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, gso: nil, transHdrLen: 0, payloadSize: 2000, @@ -2262,7 +2592,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with gso none", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, payloadSize: 1400, @@ -2273,7 +2603,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with big header", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 100, payloadSize: 1200, @@ -2448,8 +2778,8 @@ func TestFragmentationErrors(t *testing.T) { wantError: tcpip.ErrAborted, }, { - description: "Error on packet with MTU smaller than transport header", - mtu: 1280, + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, transHdrLen: 1500, payloadSize: 500, allowPackets: 0, @@ -2457,6 +2787,16 @@ func TestFragmentationErrors(t *testing.T) { mockError: nil, wantError: tcpip.ErrMessageTooLong, }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } for _, ft := range tests { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index ac20f217e..981d1371a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -341,7 +341,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } @@ -368,7 +368,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } if ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } } }) @@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) @@ -913,13 +920,13 @@ func 
TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } if test.isValid { @@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } var tllData [header.NDPLinkLayerAddressSize]byte diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..9478f3fb7 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState = &addressState{ addressableEndpointState: a, addr: addr, + // Cache the subnet in addrState to avoid calls to addr.Subnet() as that + // results in allocations on every call. + subnet: addr.Subnet(), } a.mu.endpoints[addr.Address] = addrState addrState.mu.Lock() @@ -361,6 +364,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } @@ -664,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil) type addressState struct { addressableEndpointState *AddressableEndpointState addr tcpip.AddressWithPrefix - + subnet tcpip.Subnet // Lock ordering (from outer to inner lock ordering): // // AddressableEndpointState.mu @@ -684,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr } +// Subnet implements AddressEndpoint. +func (a *addressState) Subnet() tcpip.Subnet { + return a.subnet +} + // GetKind implements AddressEndpoint. func (a *addressState) GetKind() AddressKind { a.mu.RLock() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 0cd1da11f..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.Addr - replyTID.srcPort = rt.Port + replyTID.srcAddr = address + replyTID.srcPort = port var manip manipType switch hook { case Prerouting: @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. 
tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..7a501acdc 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 8d6d9a7f1..2d8c883cd 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -22,30 +22,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -// tableID is an index into IPTables.tables. -type tableID int +// TableID identifies a specific table. +type TableID int +// Each value identifies a specific table. const ( - natID tableID = iota - mangleID - filterID - numTables + NATID TableID = iota + MangleID + FilterID + NumTables ) -// Table names. -const ( - NATTable = "nat" - MangleTable = "mangle" - FilterTable = "filter" -) - -// nameToID is immutable. -var nameToID = map[string]tableID{ - NATTable: natID, - MangleTable: mangleID, - FilterTable: filterID, -} - // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 @@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second // all packets. 
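The conntrack change above rebuilds the transport checksum from header.PseudoHeaderChecksum and the packet's own network-layer addresses instead of going through the Route. The same pattern, sketched for UDP over a flat byte slice (the real code checksums a VectorisedView with header.ChecksumVV):

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
)

// recomputeUDPChecksum refreshes the UDP checksum after the destination
// address has been rewritten, deriving the pseudo-header from the new
// network-layer addresses.
func recomputeUDPChecksum(udp header.UDP, src, dst tcpip.Address, payload []byte) {
	udp.SetChecksum(0)
	length := uint16(len(udp) + len(payload))
	xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, src, dst, length)
	xsum = header.Checksum(payload, xsum)
	udp.SetChecksum(^udp.CalculateChecksum(xsum))
}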
func DefaultTables() *IPTables { return &IPTables{ - v4Tables: [numTables]Table{ - natID: Table{ + v4Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -81,7 +68,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -99,7 +86,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -122,8 +109,8 @@ func DefaultTables() *IPTables { }, }, }, - v6Tables: [numTables]Table{ - natID: Table{ + v6Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -146,7 +133,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -164,7 +151,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -187,10 +174,10 @@ func DefaultTables() *IPTables { }, }, }, - priorities: [NumHooks][]tableID{ - Prerouting: []tableID{mangleID, natID}, - Input: []tableID{natID, filterID}, - Output: []tableID{mangleID, natID, filterID}, + priorities: [NumHooks][]TableID{ + Prerouting: []TableID{MangleID, NATID}, + Input: []TableID{NATID, FilterID}, + Output: []TableID{MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -229,26 +216,20 @@ func EmptyNATTable() Table { } } -// GetTable returns a table by name. -func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - id, ok := nameToID[name] - if !ok { - return Table{}, false - } +// GetTable returns a table with the given id and IP version. It panics when an +// invalid id is provided. +func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { it.mu.RLock() defer it.mu.RUnlock() if ipv6 { - return it.v6Tables[id], true + return it.v6Tables[id] } - return it.v4Tables[id], true + return it.v4Tables[id] } -// ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - id, ok := nameToID[name] - if !ok { - return tcpip.ErrInvalidOptionValue - } +// ReplaceTable replaces or inserts table by name. It panics when an invalid id +// is provided. +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to // check the NAT table. 
- if tableID == natID && pkt.NatDone { + if tableID == NATID && pkt.NatDone { continue } var table Table diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 538c4625d..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -26,13 +28,6 @@ type AcceptTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (at *AcceptTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: at.NetworkProtocol, - } -} - // Action implements Target.Action. func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 @@ -44,22 +39,11 @@ type DropTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (dt *DropTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: dt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } -// ErrorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const ErrorTargetName = "ERROR" - // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. type ErrorTarget struct { @@ -67,14 +51,6 @@ type ErrorTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (et *ErrorTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: et.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") @@ -90,14 +66,6 @@ type UserChainTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (uc *UserChainTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: uc.NetworkProtocol, - } -} - // Action implements Target.Action. func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") @@ -110,50 +78,39 @@ type ReturnTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *ReturnTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } -// RedirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const RedirectTargetName = "REDIRECT" - -// RedirectTarget redirects the packet by modifying the destination port/IP. +// RedirectTarget redirects the packet to this machine by modifying the +// destination port/IP. Outgoing packets are redirected to the loopback device, +// and incoming packets are redirected to the incoming interface (rather than +// forwarded). +// // TODO(gvisor.dev/issue/170): Other flags need to be added after we support // them. 
type RedirectTarget struct { - // Addr indicates address used to redirect. - Addr tcpip.Address - - // Port indicates port used to redirect. + // Port indicates port used to redirect. It is immutable. Port uint16 - // NetworkProtocol is the network protocol the target is used with. + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *RedirectTarget) ID() TargetID { - return TargetID{ - Name: RedirectTargetName, - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither +// implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + address = tcpip.Address([]byte{127, 0, 0, 1}) } else { - rt.Addr = header.IPv6Loopback + address = header.IPv6Loopback } case Prerouting: - rt.Addr = address + // No-op, as address is already set correctly. default: panic("redirect target is supported only on output and prerouting hooks") } @@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(rt.Addr) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. 
- if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7b3f3e88b..4b86c1be9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -37,7 +37,6 @@ import ( // ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> type Hook uint -// These values correspond to values in include/uapi/linux/netfilter.h. const ( // Prerouting happens before a packet is routed to applications or to // be forwarded. @@ -86,8 +85,8 @@ type IPTables struct { mu sync.RWMutex // v4Tables and v6tables map tableIDs to tables. They hold builtin // tables only, not user tables. mu must be locked for accessing. - v4Tables [numTables]Table - v6Tables [numTables]Table + v4Tables [NumTables]Table + v6Tables [NumTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. @@ -96,7 +95,7 @@ type IPTables struct { // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. - priorities [NumHooks][]tableID + priorities [NumHooks][]TableID connections ConnTrack @@ -104,6 +103,24 @@ type IPTables struct { reaperDone chan struct{} } +// VisitTargets traverses all the targets of all tables and replaces each with +// transform(target). +func (it *IPTables) VisitTargets(transform func(Target) Target) { + it.mu.Lock() + defer it.mu.Unlock() + + for tid := range it.v4Tables { + for i, rule := range it.v4Tables[tid].Rules { + it.v4Tables[tid].Rules[i].Target = transform(rule.Target) + } + } + for tid := range it.v6Tables { + for i, rule := range it.v6Tables[tid].Rules { + it.v6Tables[tid].Rules[i].Target = transform(rule.Target) + } + } +} + // A Table defines a set of chains and hooks into the network stack. // // It is a list of Rules, entry points (BuiltinChains), and error handlers @@ -169,7 +186,6 @@ type IPHeaderFilter struct { // CheckProtocol determines whether the Protocol field should be // checked during matching. - // TODO(gvisor.dev/issue/3549): Check this field during matching. CheckProtocol bool // Dst matches the destination IP address. @@ -309,23 +325,8 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } -// A TargetID uniquely identifies a target. -type TargetID struct { - // Name is the target name as stored in the xt_entry_target struct. - Name string - - // NetworkProtocol is the protocol to which the target applies. - NetworkProtocol tcpip.NetworkProtocolNumber - - // Revision is the version of the target. - Revision uint8 -} - // A Target is the interface for taking an action for a packet. type Target interface { - // ID uniquely identifies the Target. - ID() TargetID - // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. 
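With tables addressed by stack.TableID constants rather than string names, GetTable no longer returns an ok value and ReplaceTable takes the same ID, as the updated TestWriteStats helpers earlier in this change show. A sketch of the new call shape:

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// installOutputDrop replaces the first built-in Output rule of the IPv6
// filter table with a DROP, then writes the table back by ID.
func installOutputDrop(stk *stack.Stack) *tcpip.Error {
	ipt := stk.IPTables()
	filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
	ruleIdx := filter.BuiltinChains[stack.Output]
	filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: header.IPv6ProtocolNumber}
	return ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */)
}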
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..177bf5516 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "time" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -68,7 +67,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. 
-func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +83,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,28 +110,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), + Addr: remoteAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: 0, } return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -207,7 +209,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,8 +222,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. @@ -292,8 +293,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. 
-func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..ed33418f3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -61,23 +61,20 @@ const ( ) // entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAt field is ignored due to a lack of a deterministic -// method to predict the time that an event will be dispatched. +// entries. The UpdatedAtNanos field is ignored due to a lack of a +// deterministic method to predict the time that an event will be dispatched. func entryDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // entryDiffOptsWithSort is like entryDiffOpts but also includes an option to // sort slices of entries for cases where ordering must be ignored. func entryDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + })) } func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { @@ -128,9 +125,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +191,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. 
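Link-address resolution hooks now take a stack.NetworkInterface rather than a LinkEndpoint, and the test resolvers above only use the target address. A minimal sketch of the new request signature (the full LinkAddressResolver interface has further methods not shown here):

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// exampleResolver records which targets were asked for; the local and link
// address arguments are ignored, as in the test resolvers. The channel should
// be buffered by whoever constructs the resolver.
type exampleResolver struct {
	requests chan tcpip.Address
}

func (r *exampleResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ stack.NetworkInterface) *tcpip.Error {
	r.requests <- targetAddr
	return nil
}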
@@ -294,9 +290,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +300,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +323,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +353,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +364,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +394,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +409,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +457,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, 
c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +475,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +519,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +574,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +655,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +699,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + 
Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +722,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +761,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +825,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +864,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +893,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +927,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, 
entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +944,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +954,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +999,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1026,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1067,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1079,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1098,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1128,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := 
neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1137,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1168,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1193,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1233,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1271,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1327,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } 
clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1369,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1378,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1398,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1436,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. 
wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1448,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1490,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1515,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1546,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", 
entry.Addr, diff) } } @@ -1542,37 +1599,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1656,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := 
neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1713,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1737,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1772,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..493e48031 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ 
b/pkg/tcpip/stack/neighbor_entry.go @@ -24,13 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +const ( + // immediateDuration is a duration of zero for scheduling work that needs to + // be done immediately but asynchronously to avoid deadlock. + immediateDuration time.Duration = 0 +) + // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LocalAddr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAtNanos int64 } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -106,35 +111,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. 
func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAt = time.Now() + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. 
- e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,24 +238,23 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } - sendUnicastProbe() + // Send a probe in another goroutine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(immediateDuration) case Failed: e.notifyWakersLocked() @@ -315,15 +275,77 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). 
Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines." - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + // Send a probe in another goroutine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(immediateDuration) case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() case Incomplete, Reachable, Delay, Probe, Static, Failed: // Do nothing @@ -345,21 +367,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +415,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +432,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +442,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. 
- e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +497,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..c2b763325 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -47,24 +47,27 @@ const ( entryTestNetDefaultMTU = 65536 ) +// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { + clock.Advance(immediateDuration) +} + // eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAt field is ignored due to a lack of a deterministic method to -// predict the time that an event will be dispatched. +// The UpdatedAtNanos field is ignored due to a lack of a deterministic method +// to predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. 
func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -125,14 +128,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +150,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +193,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. 
-func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -245,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -323,15 +314,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { func TestEntryUnknownToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -350,9 +342,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -367,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { func TestEntryUnknownToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) @@ -377,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Unlock() // No probes should have been sent. 
+ runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) linkRes.mu.Unlock() @@ -388,9 +383,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,11 +403,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - updatedAt := e.neigh.UpdatedAt + updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -437,7 +434,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -468,16 +465,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -487,7 +488,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() @@ -495,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { func TestEntryIncompleteToReachable(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -526,20 +520,35 @@ func TestEntryIncompleteToReachable(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = 
%q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -555,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { // to Reachable. func TestEntryAddsAndClearsWakers(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) w := sleep.Waker{} s := sleep.Sleeper{} @@ -563,7 +572,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -587,34 +614,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -626,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + 
runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -659,20 +666,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { } linkRes.mu.Unlock() + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -684,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { func TestEntryIncompleteToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -715,20 +733,35 @@ func TestEntryIncompleteToStale(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +777,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +816,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: 
Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -817,12 +854,30 @@ func (*testLocker) Unlock() {} func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -848,34 +903,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -893,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -928,20 +959,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch 
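For TestEntryIncompleteToFailed, the Added(Incomplete) and Removed(Incomplete) events asserted above correspond to resolution giving up: probes are retransmitted until the multicast probe budget is exhausted and the entry is dropped. A hedged sketch of that retry loop, using assumed RFC 4861-style defaults rather than the real NUDConfigurations values:

```go
// Illustration of the Incomplete -> Failed behaviour, not the netstack
// implementation. Constants are assumptions mirroring RFC 4861 defaults.
package main

import (
	"fmt"
	"time"
)

const (
	maxMulticastProbes = 3
	retransmitTimer    = time.Second
)

// resolve reports how many probes were sent and how long resolution was
// attempted before failing when no confirmation ever arrives.
func resolve(confirmed func(attempt int) bool) (probes int, waited time.Duration) {
	for attempt := 1; attempt <= maxMulticastProbes; attempt++ {
		probes++
		if confirmed(attempt) {
			return probes, waited
		}
		waited += retransmitTimer
	}
	return probes, waited // Failed: the caller removes the entry.
}

func main() {
	probes, waited := resolve(func(int) bool { return false })
	fmt.Printf("sent %d probes, gave up after %v\n", probes, waited)
}
```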
(-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,17 +1014,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -986,29 +1032,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1026,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = 
%q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1058,27 +1110,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1086,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1132,27 +1184,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + 
Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1160,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1206,27 +1262,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: 
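The Reachable-to-Stale tests above exercise the usual confirmation rules: an unsolicited confirmation carrying a different link address demotes a Reachable entry but keeps the cached address unless Override is set, while a solicited confirmation promotes the entry to Reachable and records the advertised address. A loose illustration of those rules (roughly RFC 4861 section 7.2.5, not the handleConfirmationLocked implementation):

```go
// Hedged sketch of confirmation handling for an entry that already has a
// cached link address (Reachable or Stale); Incomplete is handled differently
// and is not modelled here.
package main

import "fmt"

type State string

const (
	Reachable State = "Reachable"
	Stale     State = "Stale"
)

type flags struct{ Solicited, Override bool }

func applyConfirmation(state State, current, confirmed string, f flags) (State, string) {
	sameAddr := current == confirmed
	switch {
	case !sameAddr && !f.Override:
		// Different address without Override: only a Reachable entry is
		// demoted, and the cached address is kept.
		if state == Reachable {
			return Stale, current
		}
		return state, current
	case f.Solicited:
		return Reachable, confirmed
	case !sameAddr:
		// Unsolicited override with a new address: record it, mark Stale.
		return Stale, confirmed
	default:
		// Unsolicited confirmation repeating the known address: no change.
		return state, current
	}
}

func main() {
	// Reachable entry, unsolicited confirmation with a different address.
	s, addr := applyConfirmation(Reachable, "aa:aa:aa:aa:aa:01", "aa:aa:aa:aa:aa:02", flags{})
	fmt.Println(s, addr) // Stale aa:aa:aa:aa:aa:01
}
```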
entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1234,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1279,20 +1340,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1304,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - 
t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1343,27 +1408,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1375,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } e.mu.Lock() - e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1400,41 +1511,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: 
Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1446,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1485,27 +1570,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1517,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { c := DefaultNUDConfigurations() 
- e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1552,27 +1651,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1584,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { func TestEntryStaleToDelay(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1616,27 +1728,48 @@ func TestEntryStaleToDelay(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale 
{ + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,22 +1789,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleUpperLevelConfirmationLocked() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1686,43 +1807,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleUpperLevelConfirmationLocked() + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: 
entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,29 +1889,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1780,43 +1907,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1996,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t 
*testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1860,57 +2037,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1922,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, 
entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1962,27 +2115,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1994,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2027,34 +2197,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := 
[]testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2066,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2103,34 +2281,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: 
entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,69 +2351,91 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,36 +2456,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + 
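TestEntryDelayToProbe and the tests after it switch to checking probes in two batches: assert the initial multicast probe, clear linkRes.probes under the lock, advance the clock by DelayFirstProbeTime, and then assert only the unicast probe. A stand-in recorder showing that assert-then-clear pattern:

```go
// Sketch of the assert-then-clear spy pattern; the recorder is a stand-in
// for the test's linkRes, and clearing mirrors `linkRes.probes = nil`.
package main

import (
	"fmt"
	"sync"

	"github.com/google/go-cmp/cmp"
)

type probe struct {
	RemoteAddress     string
	RemoteLinkAddress string
}

type probeRecorder struct {
	mu     sync.Mutex
	probes []probe
}

func (r *probeRecorder) record(p probe) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.probes = append(r.probes, p)
}

// checkAndClear compares the recorded probes against want and resets the
// recorder so the next phase starts from an empty slice.
func (r *probeRecorder) checkAndClear(want []probe) string {
	r.mu.Lock()
	defer r.mu.Unlock()
	diff := cmp.Diff(r.probes, want)
	r.probes = nil
	return diff
}

func main() {
	r := &probeRecorder{}
	// Phase 1: multicast probe sent while the link address is still unknown.
	r.record(probe{RemoteAddress: "fe80::1"})
	fmt.Println(r.checkAndClear([]probe{{RemoteAddress: "fe80::1"}}) == "")

	// Phase 2: unicast probe sent once the entry has a cached link address.
	r.record(probe{RemoteAddress: "fe80::1", RemoteLinkAddress: "aa:aa:aa:aa:aa:01"})
	fmt.Println(r.checkAndClear([]probe{{RemoteAddress: "fe80::1", RemoteLinkAddress: "aa:aa:aa:aa:aa:01"}}) == "")
}
```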
RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2274,37 +2516,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2312,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { @@ -2325,36 +2571,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + 
RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2375,37 +2635,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2413,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { @@ -2426,36 +2690,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + 
wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2479,30 +2758,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,17 +2816,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - // Probe caused by the Delay-to-Probe transition { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2567,42 +2851,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: 
entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,36 +2915,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. 
- { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2672,49 +2979,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,36 +3052,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. 
- { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2781,49 +3113,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,36 +3186,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. 
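Editor's note: the test hunks above and below repeat the same sequence after every clock advance — lock the resolver, diff the recorded probes against a wanted slice, clear the recorded probes, unlock, and fail on mismatch. As a non-authoritative sketch of that pattern (not part of this change), a helper along the following lines could factor it out; probeInfo and fakeLinkResolver are simplified stand-ins for the test's entryTestProbeInfo and link resolver types.

package neighprobe_sketch

import (
	"sync"
	"testing"

	"github.com/google/go-cmp/cmp"
)

// probeInfo is a simplified stand-in for entryTestProbeInfo.
type probeInfo struct {
	RemoteAddress     string
	RemoteLinkAddress string
	LocalAddress      string
}

// fakeLinkResolver is a simplified stand-in for the test link resolver that
// records every probe it is asked to send.
type fakeLinkResolver struct {
	mu     sync.Mutex
	probes []probeInfo
}

// expectProbes compares the probes recorded so far against want and clears
// the record, so the next call only sees probes sent after this point.
func expectProbes(t *testing.T, linkRes *fakeLinkResolver, want []probeInfo) {
	t.Helper()
	linkRes.mu.Lock()
	diff := cmp.Diff(linkRes.probes, want)
	linkRes.probes = nil
	linkRes.mu.Unlock()
	if diff != "" {
		t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff)
	}
}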
- { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2890,49 +3247,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3314,116 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. 
+ for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3034,12 +3431,6 @@ func TestEntryProbeToFailed(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,84 +3445,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + 
linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The next three probe are sent in Probe. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..60c81a3aa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. 
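Editor's note: TestEntryFailedGetsDeleted above waits out the Delay state, every unicast probe, and the unreachable grace period before expecting removal. The standalone sketch below only restates that arithmetic; the constant values are placeholders standing in for the configuration the test uses, not the library's authoritative defaults.

package nudtiming_sketch

import (
	"fmt"
	"time"
)

// Placeholder values standing in for the NUD configuration used by the test.
const (
	delayFirstProbeTime = 5 * time.Second
	retransmitTimer     = time.Second
	maxUnicastProbes    = 3
	unreachableTime     = 5 * time.Second
)

func main() {
	// Mirrors the waitFor computation in TestEntryFailedGetsDeleted: one
	// Delay period, one retransmit interval per unicast probe, then the
	// unreachable grace period before the entry is removed.
	waitFor := delayFirstProbeTime + retransmitTimer*maxUnicastProbes + unreachableTime
	fmt.Println(waitFor) // 13s with the placeholder values above.
}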
+func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -339,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -546,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { } func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + n.getNetworkEndpoint(protocol).HandlePacket(pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -585,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] @@ -660,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.nic + n := r.localAddressNIC if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote + pkt.NICID = n.ID() r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + pkt.NetworkPacketInfo = r.networkPacketInfo() + n.getNetworkEndpoint(protocol).HandlePacket(pkt) addressEndpoint.DecRef() r.Release() return @@ -678,7 +697,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease + // the TTL field for ipv4/ipv6. // pkt may have set its header and may not have enough headroom for // link-layer header for the other link to prepend. 
Here we create a new @@ -725,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -737,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -766,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -781,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -885,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. 
func (e *testIPv6Endpoint) Close() { @@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. 
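Editor's note: the NUDDispatcher hunks above collapse the per-field callback arguments into a single NeighborEntry value. The sketch below shows the shape an implementation takes after that change; nicID and neighborEntry are simplified stand-ins for tcpip.NICID and the stack's NeighborEntry, and eventRecorder is hypothetical.

package nuddisp_sketch

import "sync"

// Simplified stand-ins for tcpip.NICID and stack.NeighborEntry.
type nicID int32

type neighborEntry struct {
	Addr     string
	LinkAddr string
	State    string
}

// eventRecorder records every event it is told about; each callback now
// receives the NIC ID and a full entry value rather than unpacked fields.
type eventRecorder struct {
	mu     sync.Mutex
	events []neighborEntry
}

func (d *eventRecorder) OnNeighborAdded(_ nicID, e neighborEntry)   { d.record(e) }
func (d *eventRecorder) OnNeighborChanged(_ nicID, e neighborEntry) { d.record(e) }
func (d *eventRecorder) OnNeighborRemoved(_ nicID, e neighborEntry) { d.record(e) }

func (d *eventRecorder) record(e neighborEntry) {
	d.mu.Lock()
	defer d.mu.Unlock()
	d.events = append(d.events, e)
}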
func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..b8f333057 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,28 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // RemoteAddressBroadcast is true if the packet's remote address is a + // broadcast address. + RemoteAddressBroadcast bool + + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. 
// - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +116,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -172,9 +183,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +238,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface { // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix + // Subnet returns the subnet of the endpoint's address. + Subnet() tcpip.Subnet + // IsAssigned returns whether or not the endpoint is considered bound // to its NetworkEndpoint. IsAssigned(allowExpired bool) bool @@ -490,6 +504,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -544,7 +561,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -764,13 +781,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. 
+ // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index b76e2d37b..15ff437c7 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,11 +47,16 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +67,144 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + addrWithPrefix := addressEndpoint.AddressWithPrefix() + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + addressEndpoint.DecRef() + return Route{} + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = addrWithPrefix.Address + } + + r := makeRoute( + netProto, + addrWithPrefix.Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. 
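Editor's note: constructAndValidateRoute above sends a gatewayless route to the Ethernet broadcast address when the remote address is the subnet broadcast. The following is a simplified, IPv4-only illustration of that check using net.IPNet; it is a sketch of the rule, not the stack's Subnet.IsBroadcast implementation.

package broadcast_sketch

import "net"

// isSubnetBroadcast reports whether addr is the broadcast address of n: the
// address must fall inside the subnet and have all host bits set, and
// point-to-point (/31) and host (/32) networks are treated as having no
// broadcast address.
func isSubnetBroadcast(addr net.IP, n *net.IPNet) bool {
	ip4 := addr.To4()
	if ip4 == nil || !n.Contains(ip4) {
		return false
	}
	if ones, bits := n.Mask.Size(); bits != 32 || ones >= 31 {
		return false
	}
	for i, maskByte := range n.Mask {
		if ip4[i]|maskByte != 0xff {
			return false
		}
	}
	return true
}

For example, within 192.168.1.0/24 only 192.168.1.255 satisfies the check, so only that destination would have its remote link address forced to the Ethernet broadcast MAC.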
-func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, - Loop: loop, + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + localAddressEndpoint: localAddressEndpoint, + outgoingNIC: outgoingNIC, + Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// PopulatePacketInfo populates a packet buffer's packet information fields. 
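Editor's note: makeRoute above decides whether a packet should also be looped back into the stack, and now skips that decision entirely for loopback NICs. The sketch below condenses that branch structure with stand-in flag values; packetLooping and its constants are assumptions for illustration, not the stack's real type.

package looping_sketch

// Stand-in flags for the stack's PacketLooping type.
type packetLooping uint8

const (
	packetOut  packetLooping = 1 << iota // hand the packet to the link endpoint
	packetLoop                           // also loop it back into the stack
)

// looping condenses the decision made in makeRoute: loopback NICs already
// loop packets back at the link layer, so they never add packetLoop here.
func looping(loopbackNIC, handleLocal, toSelf, multicastLoop, multicast, broadcast bool) packetLooping {
	loop := packetOut
	if loopbackNIC {
		return loop
	}
	switch {
	case handleLocal && toSelf:
		loop = packetLoop // purely local traffic never leaves the stack
	case multicastLoop && multicast:
		loop |= packetLoop
	case broadcast:
		loop |= packetLoop // limited or subnet broadcast
	}
	return loop
}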
+// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { + if r.local() { + pkt.RXTransportChecksumValidated = true + } + pkt.NetworkPacketInfo = r.networkPacketInfo() +} + +// networkPacketInfo returns the network packet information of the route. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) networkPacketInfo() NetworkPacketInfo { + return NetworkPacketInfo{ + RemoteAddressBroadcast: r.IsOutboundBroadcast(), + LocalAddressBroadcast: r.isInboundBroadcast(), + } +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. 
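Editor's note: the route.go hunk above replaces the raw Capabilities accessor with intent-named helpers such as RequiresTXTransportChecksum. The sketch below shows, with stand-in types, how a transport-layer writer consumes that kind of helper; the linkCaps type, the route struct, and the placeholder checksum arithmetic are all assumptions, not the stack's real API.

package txchecksum_sketch

type linkCaps uint32

const capTXChecksumOffload linkCaps = 1 << 0

type route struct {
	local bool
	caps  linkCaps
}

// requiresTXTransportChecksum mirrors the rule above: local (looped-back)
// routes never need a software checksum, and neither do links that offload
// TX checksumming.
func (r route) requiresTXTransportChecksum() bool {
	if r.local {
		return false
	}
	return r.caps&capTXChecksumOffload == 0
}

// fillChecksum shows the call-site shape: only compute a checksum when the
// route says one is needed. The arithmetic is a placeholder, not the real
// Internet checksum.
func fillChecksum(r route, payload []byte) uint16 {
	if !r.requiresTXTransportChecksum() {
		return 0
	}
	var sum uint16
	for _, b := range payload {
		sum += uint16(b)
	}
	return sum
}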
+ var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } @@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } @@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before the this route can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. 
func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + if r.localAddressEndpoint != nil { + r.localAddressEndpoint.DecRef() + r.localAddressEndpoint = nil } } // Clone clones the route. func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() + if r.localAddressEndpoint != nil { + if !r.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) + } } return *r } @@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + subnet := r.localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool { return r.isV4Broadcast(r.RemoteAddress) } -// IsInboundBroadcast returns true if the route is for an inbound broadcast +// isInboundBroadcast returns true if the route is for an inbound broadcast // packet. -func (r *Route) IsInboundBroadcast() bool { +func (r *Route) isInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.LocalAddress) } @@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool { // ReverseRoute returns new route with given source and destination address. 
func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + localAddressEndpoint: r.localAddressEndpoint, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -518,6 +519,10 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // IPTables are the initial iptables rules. If nil, iptables will allow + // all traffic. + IPTables *IPTables } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +625,10 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } + if opts.IPTables == nil { + opts.IPTables = DefaultTables() + } + opts.NUDConfigs.resetInvalidFields() s := &Stack{ @@ -633,7 +642,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - tables: DefaultTables(), + tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, @@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -830,6 +839,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. 
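Editor's note: the stack.go hunk above adds RemoveRoutes, which drops every route table entry matched by a caller-supplied predicate. The usage sketch below is hypothetical (the dropRoutesForNIC wrapper is not part of the change), but it uses only the signature introduced above and the existing tcpip.Route fields.

package removeroutes_sketch

import (
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
)

// dropRoutesForNIC removes every route that leaves through the given NIC,
// e.g. when that interface is being torn down.
func dropRoutesForNIC(s *stack.Stack, nicID tcpip.NICID) {
	// RemoveRoutes keeps a route only if the predicate returns false for it.
	s.RemoveRoutes(func(r tcpip.Route) bool {
		return r.NIC == nicID
	})
}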
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1180,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. 
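To make the documented defaults concrete, a caller may leave both the NIC and the local address unspecified and let the stack choose them; remoteAddr and the IPv4 protocol number below are placeholder assumptions:

    // NIC 0 means "any NIC"; an empty local address means "stack-selected".
    r, err := s.FindRoute(0, "", remoteAddr, ipv4.ProtocolNumber, false /* multicastLoop */)
    if err != nil {
        log.Fatalf("FindRoute failed: %s", err)
    }
    defer r.Release()
    _ = r.LocalAddress // The address the stack selected.
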
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. 
- remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } + + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1323,7 +1517,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. @@ -1443,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. 
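The fallback branch in the FindRoute hunk above only fires when forwarding is enabled for the protocol, so a route whose local address lives on a different NIC than the outgoing interface can only be constructed after SetForwarding. A sketch under that assumption (nic1Addr is assigned to NIC 1 while the only route table entry points at NIC 2):

    if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
        log.Fatalf("SetForwarding failed: %s", err)
    }
    // With forwarding enabled this may return a route that keeps nic1Addr as
    // the local address while sending through NIC 2; with forwarding disabled
    // the same call fails with tcpip.ErrNoRoute.
    r, err := s.FindRoute(0, nic1Addr, remoteAddr, ipv4.ProtocolNumber, false /* multicastLoop */)
    if err != nil {
        log.Fatalf("FindRoute failed: %s", err)
    }
    defer r.Release()
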
-func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1896,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..dedfdd435 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ // Handle control packets. 
- if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + pkt := pkt.Clone() + r.PopulatePacketInfo(pkt) + f.HandlePacket(pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. 
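The fake endpoint's WritePacket hunk above shows the pattern network endpoints follow for looped traffic now that HandlePacket no longer receives a *Route: clone the buffer, stamp the route's delivery metadata onto the copy, and hand the copy back to the receive path. A condensed sketch of that pattern inside a hypothetical endpoint e:

    if r.Loop&stack.PacketLoop != 0 {
        // The receive path consumes its buffer, so loop a copy and keep the
        // original for transmission.
        looped := pkt.Clone()
        // Attach the route's delivery metadata to the packet, since the
        // receive side only sees the PacketBuffer.
        r.PopulatePacketInfo(looped)
        e.HandlePacket(looped)
    }
    if r.Loop&stack.PacketOut == 0 {
        return nil
    }
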
+func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { RemoteAddress: ipv4SubnetBcast, RemoteLinkAddress: header.EthernetBroadcastAddress, NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, + Loop: stack.PacketOut | stack.PacketLoop, }, }, // Broadcast to a locally attached /31 subnet does not populate the @@ -3672,3 +3778,453 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", 
"\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: 
nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: 
ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + 
t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + defer r.Release() + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. 
-func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. 
} @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. 
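With the Route parameter gone, the demuxer above keys everything off fields carried by the PacketBuffer itself: NetworkProtocolNumber selects the protocol table, NICID selects the per-NIC endpoint map, and NetworkPacketInfo.LocalAddressBroadcast feeds the broadcast check. A hedged sketch of the state a receive path is expected to have populated before delivery (the concrete values here are assumptions):

    pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
    pkt.NICID = nicID
    pkt.NetworkPacketInfo.LocalAddressBroadcast = isSubnetBroadcast
    dispatcher.DeliverTransportPacket(header.UDPProtocolNumber, pkt)
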
-func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..c457b67a2 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -28,7 +28,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts @@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. 
f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return + } + route.ResolveWith(pkt.SourceLinkAddress()) + + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + }) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d77848d61..3ab2b7654 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -763,6 +762,10 @@ const ( // endpoint that all packets being written have an IP header and the // endpoint should not attach an IP header. IPHdrIncludedOption + + // AcceptConnOption is used by GetSockOptBool to indicate if the + // socket is a listening socket. + AcceptConnOption ) // SockOptInt represents socket options which values have the int type. @@ -1256,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -1496,6 +1505,15 @@ type IPStats struct { // IPTablesOutputDropped is the total number of IP packets dropped in // the Output chain. IPTablesOutputDropped *StatCounter + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived *StatCounter + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived *StatCounter + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived *StatCounter } // TCPStats collects TCP-specific stats. 
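Route.Equal pairs naturally with the new Stack.RemoveRoutes, and AcceptConnOption is read through the endpoint's boolean socket-option getter. A short sketch in which s, ep and staleRoute are placeholders:

    // Remove exactly one known-stale entry from the route table.
    s.RemoveRoutes(func(r tcpip.Route) bool {
        return r.Equal(staleRoute)
    })

    // Ask an endpoint whether it is currently a listening socket.
    listening, err := ep.GetSockOptBool(tcpip.AcceptConnOption)
    if err != nil {
        log.Fatalf("GetSockOptBool failed: %s", err)
    }
    _ = listening
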
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 34aab32d0..9b0f3b675 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -10,6 +10,7 @@ go_test( "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", + "route_test.go", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0dcef7b04..bf7594268 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -33,11 +33,6 @@ import ( func TestForwarding(t *testing.T) { const ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - host1NICID = 1 routerNICID1 = 2 routerNICID2 = 3 @@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) { } }, }, + { + name: "IPv4 host2 server with routerNIC1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 routerNIC2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, } for _, test := range tests { @@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) { if err == tcpip.ErrNoLinkAddress { // Wait for link resolution to complete. <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } else if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) } - if err != nil { t.Fatalf("ep.Write(_, _): %s", err) } @@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. 
<-ch - var addr tcpip.FullAddress v, _, err := ep.Read(&addr) if err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 6ddcda70c..fe7c1bb3d 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -32,32 +32,36 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +const ( + linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +) - host1IPv4Addr = tcpip.ProtocolAddress{ +var ( + ipv4Addr1 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - host2IPv4Addr = tcpip.ProtocolAddress{ + ipv4Addr2 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr2 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), @@ -89,7 +93,7 @@ func TestPing(t *testing.T) { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, - remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) @@ -104,7 +108,7 @@ func TestPing(t *testing.T) { name: "IPv6 Ping", transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, - remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) @@ -127,7 +131,7 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -143,36 +147,36 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + if err := 
host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, tcpip.Route{ - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, tcpip.Route{ - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, }) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index e8caf09ba..421da1add 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }, }) - wq := waiter.Queue{} + var wq waiter.Queue rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index f1028823b..cdf0459e3 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - wq := waiter.Queue{} + var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) @@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) { loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") ) - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) - tests := []struct { name string broadcastAddr tcpip.Address @@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, }) + type endpointAndWaiter struct { + ep tcpip.Endpoint + ch chan struct{} + } + var eps []endpointAndWaiter // We create endpoints that bind to both the wildcard address and the // broadcast address to make sure both of these types of "broadcast // interested" endpoints receive broadcast packets. - wq := waiter.Queue{} - var eps []tcpip.Endpoint for _, bindWildcard := range []bool{false, true} { // Create multiple endpoints for each type of "broadcast interested" // endpoint so we can test that all endpoints receive the broadcast // packet. 
for i := 0; i < 2; i++ { + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) @@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } } - eps = append(eps, ep) + eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) } } @@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - if n, _, err := wep.Write(data, writeOpts); err != nil { + data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) + if n, _, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) } for j, rep := range eps { - if gotPayload, _, err := rep.Read(nil); err != nil { + // Wait for the endpoint to become readable. + <-rep.ch + + if gotPayload, _, err := rep.ep.Read(nil); err != nil { t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go new file mode 100644 index 000000000..02fc47015 --- /dev/null +++ b/pkg/tcpip/tests/integration/route_test.go @@ -0,0 +1,388 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLocalPing tests pinging a remote that is local the stack. +// +// This tests that a local route is created and packets do not leave the stack. +func TestLocalPing(t *testing.T) { + const ( + nicID = 1 + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. 
+ icmpDataOffset = 8 + ) + + channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } + channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { + channelEP := e.(*channel.Endpoint) + if n := channelEP.Drain(); n != 0 { + t.Fatalf("got channelEP.Drain() = %d, want = 0", n) + } + } + + ipv4ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + ipv6ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + linkEndpoint func() stack.LinkEndpoint + localAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + expectedConnectErr *tcpip.Error + checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) + }{ + { + name: "IPv4 loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: ipv4Loopback, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: header.IPv6Loopback, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv4Addr.Address, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv6Addr.Address, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv4 loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: 
[]stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } + + if test.expectedConnectErr != nil { + return + } + + payload := tcpip.SlicePayload(test.icmpBuf(t)) + var wOpts tcpip.WriteOptions + if n, _, err := ep.Write(payload, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } + + // Wait for the endpoint to become readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } +} + +// TestLocalUDP tests sending UDP packets between two endpoints that are local +// to the stack. +// +// This tests that that packets never leave the stack and the addresses +// used when sending a packet. 
+func TestLocalUDP(t *testing.T) { + const ( + nicID = 1 + ) + + tests := []struct { + name string + canBePrimaryAddr tcpip.ProtocolAddress + firstPrimaryAddr tcpip.ProtocolAddress + }{ + { + name: "IPv4", + canBePrimaryAddr: ipv4Addr1, + firstPrimaryAddr: ipv4Addr2, + }, + { + name: "IPv6", + canBePrimaryAddr: ipv6Addr1, + firstPrimaryAddr: ipv6Addr2, + }, + } + + subTests := []struct { + name string + addAddress bool + expectedWriteErr *tcpip.Error + }{ + { + name: "Unassigned local address", + addAddress: false, + expectedWriteErr: tcpip.ErrNoRoute, + }, + { + name: "Assigned local address", + addAddress: true, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + s := stack.New(stackOpts) + ep := channel.New(1, header.IPv6MinimumMTU, "") + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if subTest.addAddress { + if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + } + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer server.Close() + + bindAddr := tcpip.FullAddress{Port: 80} + if err := server.Bind(bindAddr); err != nil { + t.Fatalf("server.Bind(%#v): %s", bindAddr, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventIn) + client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer client.Close() + + serverAddr := tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + Port: 80, + } + + clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &serverAddr, + } + if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + } else if subTest.expectedWriteErr != nil { + // Nothing else to test if we expected not to be able to send the + // UDP packet. + return + } else if n != int64(len(clientPayload)) { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + } + } + + // Wait for the server endpoint to become readable. 
+ <-serverCH + + var clientAddr tcpip.FullAddress + if v, _, err := server.Read(&clientAddr); err != nil { + t.Fatalf("server.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + } + if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { + t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + } + if t.Failed() { + t.FailNow() + } + } + + serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &clientAddr, + } + if n, _, err := server.Write(serverPayload, wOpts); err != nil { + t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) + } else if n != int64(len(serverPayload)) { + t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + } + } + + // Wait for the client endpoint to become readable. + <-clientCH + + var gotServerAddr tcpip.FullAddress + if v, _, err := client.Read(&gotServerAddr); err != nil { + t.Fatalf("client.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + } + if gotServerAddr.Addr != serverAddr.Addr { + t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + } + if t.Failed() { + t.FailNow() + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 41eb0ca44..763cd8f84 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil default: @@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, }, } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 87d510f96..3820e5dc7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. 
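// A pattern worth noting across this diff: with *stack.Route removed from the
// transport receive path, handlers recover sender information from the
// PacketBuffer itself. A minimal sketch using only fields already referenced
// in these files (pkt.NICID, pkt.Network()):
//
//	senderAddr := tcpip.FullAddress{
//		NIC:  pkt.NICID,
//		Addr: pkt.Network().SourceAddress(),
//	}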
-func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 072601d2d..31831a6d8 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrNotSupported + switch opt { + case tcpip.AcceptConnOption: + return false, nil + default: + return false, tcpip.ErrNotSupported + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e37c00523..7b6a87ba9 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.IPHdrIncludedOption: @@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { return } + remoteAddr := pkt.Network().SourceAddress() + if e.bound { // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != route.NICID() { + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() return } @@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != route.RemoteAddress { + if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() return } @@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // Push new packet into receive list and increment the buffer size. 
packet := &rawPacket{ senderAddr: tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress, + NIC: pkt.NICID, + Addr: remoteAddr, }, } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 33bfb56cd..7d97cbdc7 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { } // beforeSave is invoked by stateify. -func (ep *endpoint) beforeSave() { +func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming // packets. - ep.rcvMu.Lock() + e.rcvMu.Lock() } // saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 + e.rcvBufSizeMax = 0 // Release the lock acquired in beforeSave() so regular endpoint closing // logic can proceed after save. - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return max } // loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max } // afterLoad is invoked by stateify. -func (ep *endpoint) afterLoad() { - stack.StackFromEnv.RegisterRestoredEndpoint(ep) +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s // If the endpoint is connected, re-connect. - if ep.connected { + if e.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) if err != nil { panic(err) } } // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + if e.bound { + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if ep.associated { - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + if e.associated { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b706438bd..47982ca41 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. 
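// Note on the signature change below: segments no longer carry a cloned
// stack.Route (see the segment rework later in this diff), so the listener
// re-derives a route via l.stack.FindRoute using the segment's nicID and
// addresses, and pre-resolves the remote link address captured from the
// incoming packet with route.ResolveWith(s.remoteLinkAddr). That lookup can
// fail, which is why createConnectingEndpoint now also returns a *tcpip.Error.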
-func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { - netProto = s.route.NetProto + netProto = s.netProto } + + route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(s.remoteLinkAddr) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6Only n.ID = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.boundNICID = s.nicID + n.route = route + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} n.rcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + if err != nil { + return nil, err + } // Lock the endpoint before registering to ensure that no out of // band changes are possible due to incoming packets etc till @@ -425,20 +435,17 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer ctx.synRcvdCount.dec() - defer func() { - e.mu.Lock() - e.decSynRcvdCount() - e.mu.Unlock() - }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() + e.decSynRcvdCount() return } ctx.removePendingEndpoint(n) + e.decSynRcvdCount() n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() @@ -456,7 +463,9 @@ func (e *endpoint) incSynRcvdCount() bool { } func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() e.synRcvdCount-- + e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { @@ -468,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -478,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. 
- replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } switch { @@ -493,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. - return + return nil } ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } else { // If cookies are in use but the endpoint accept queue // is full then drop the syn. @@ -507,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Send SYN without window scaling because we currently // don't encode this information in the cookie. // @@ -524,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, s.route), + MSS: calculateAdvertisedMSS(e.userMSS, route), } - e.sendSynTCP(&s.route, tcpFields{ + fields := tcpFields{ id: s.id, ttl: e.ttl, tos: e.sendTOS, @@ -534,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { seq: cookie, ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, - }, synOpts) + } + if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + return err + } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil } case (s.flags & header.TCPFlagAck) != 0: @@ -548,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } if !ctx.synRcvdCount.synCookiesInUse() { @@ -567,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } iss := s.ackNumber - 1 @@ -588,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. 
@@ -609,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + if err != nil { + return err + } n.mu.Lock() @@ -623,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return nil } // Register new endpoint so that packets are routed to it. @@ -633,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return err } n.isRegistered = true @@ -671,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) + return nil + + default: + return nil } } // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.v6only ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) @@ -715,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { - return nil + return } if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { s := e.segmentQueue.dequeue() - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } close(e.drainDone) @@ -739,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { break } - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 0aaef495d..2facbebec 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { MSS: amss, } if ttl == 0 { - ttl = s.route.DefaultTTL() + ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } // Wait for notification. - index, _ = s.Fetch(true) + h.ep.mu.Unlock() + index, _ = s.Fetch(true /* block */) + h.ep.mu.Lock() } } @@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error { }, synOpts) for h.state != handshakeCompleted { + // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held + // throughout handshake processing). 
h.ep.mu.Unlock() - index, _ := s.Fetch(true) + index, _ := s.Fetch(true /* block */) h.ep.mu.Lock() switch index { @@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // TCP header, then the kernel calculate a checksum of the // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) - } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + } else if r.RequiresTXTransportChecksum() { xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } if ep == nil { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) s.decRef() return } @@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ drained := e.drainDone != nil if drained { close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } // Set up the functions that will be called when the main protocol loop @@ -1535,7 +1541,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() // We need to double check here because the notification may be @@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} } for _, netProto := range netProtos { - if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil { tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { @@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() switch v { case newSegment: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 98aecab9e..21162f01a 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -172,10 +172,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newSegment(r, id, pkt) - if !s.parse() { + + s := newIncomingSegment(id, pkt) + if !s.parse(pkt.RXTransportChecksumValidated) { ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() 
ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 560b4904c..a6f25896b 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) { testV6Connect(t, c) } +func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: defaultMTU, + }) + defer c.Cleanup() + + // Create a v6 endpoint but don't set the v6-only TCP option. + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3bcd3923a..258f9f1bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -721,9 +721,9 @@ func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned // by another user goroutine if not then we spin, otherwise - // we just goto sleep on the Lock() and wait. + // we just go to sleep on the Lock() and wait. if !e.mu.TryLock() { - // If socket is owned by the user then just goto sleep + // If socket is owned by the user then just go to sleep // as the lock could be held for a reasonably long time. if atomic.LoadUint32(&e.ownedByUser) == 1 { e.mu.Lock() @@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) + s := newOutgoingSegment(e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.MulticastLoopOption: return true, nil + case tcpip.AcceptConnOption: + e.LockUser() + defer e.UnlockUser() + + return e.EndpointState() == StateListen, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -2310,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // done yet) or the reservation was freed between the check above and // the FindTransportEndpoint below. But rather than retry the same port // we just skip it and move on. - transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID()) if transEP == nil { // ReservePort failed but there is no registered endpoint with // demuxer. Which indicates there is at least some endpoint that has @@ -2379,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.ID - s.route = r.Clone() e.sndWaker.Assert() } } @@ -2445,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.ID, nil) + s := newOutgoingSegment(e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ // Mark endpoint as closed. 
@@ -2627,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + + // Expand netProtos to include v4 and v6 under dual-stack if the caller is + // binding to a wildcard (empty) address, and this is an IPv6 endpoint with + // v6only set to false. + if netProto == header.IPv6ProtocolNumber { + stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) + alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + if alsoBindToV4 { + netProtos = append(netProtos, header.IPv4ProtocolNumber) } } @@ -2715,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { } } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -3074,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() { } func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() - } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + } else if e.route.HasSoftwareGSOCapability() { e.gso = &stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b25431467..2bcc5e1c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() { switch { case epState == StateInitial || epState == StateBound: case epState.connected() || epState.handshake(): - if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { - if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + if !e.route.HasSaveRestoreCapability() { + if !e.route.HasDisconncetOkCapability() { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 070b634b4..0664789da 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -30,6 +30,8 @@ import ( // The canonical way of using it is to pass the Forwarder.HandlePacket function // to stack.SetTransportProtocolHandler. 
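// Note on the new stack field below: the forwarder records the Stack it was
// created with (see NewForwarder) so that Complete(sendReset = true) can build
// a reply route via replyWithReset without relying on a per-segment route,
// matching the segment changes elsewhere in this diff.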
type Forwarder struct { + stack *stack.Stack + maxInFlight int handler func(*ForwarderRequest) @@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ + stack: s, maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), @@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newSegment(r, id, pkt) +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newIncomingSegment(id, pkt) defer s.decRef() // We only care about well-formed SYN packets. - if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn { return false } @@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) { delete(r.forwarder.inFlight, r.segment.id) r.forwarder.mu.Unlock() - // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 5bce73605..2329aca4b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(r, ep, id, pkt) +func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(ep, id, pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newSegment(r, id, pkt) +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + s := newIncomingSegment(id, pkt) defer s.decRef() - if !s.parse() || !s.csumValid { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } if !s.flagIsSet(header.TCPFlagRst) { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(p.stack, s, stack.DefaultTOS, 0) } return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. 
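// Callers now pass the stack explicitly and may pass a zero TTL to fall back
// to the reply route's default TTL. The call sites updated in this change take
// one of two shapes:
//
//	return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
//
// or, where no per-endpoint TTL applies:
//
//	replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */)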
-func replyWithReset(s *segment, tos, ttl uint8) { +// +// If the passed TTL is 0, then the route's default TTL will be used. +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { + route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, tcpFields{ + + if ttl == 0 { + ttl = route.DefaultTTL() + } + + return sendTCP(&route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 7ef2df377..833a7b470 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { return found } -// Dump prints the state of the scoreboard structure. +// String returns human-readable state of the scoreboard structure. func (s *SACKScoreboard) String() string { var str strings.Builder str.WriteString("SACKScoreboard: {") diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1f9c5cf50..2091989cc 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -45,9 +46,18 @@ type segment struct { ep *endpoint qFlags queueFlags id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + + // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of + // individual members for link/network packet info. + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID + remoteLinkAddr tcpip.LinkAddress + + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. 
views [8]buffer.View `state:"nosave"` @@ -76,11 +86,16 @@ type segment struct { acked bool } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, + remoteLinkAddr: pkt.SourceLinkAddress(), } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB return s } -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, - route: r.Clone(), } s.rcvdTime = time.Now() if len(v) != 0 { @@ -110,7 +124,9 @@ func (s *segment) clone() *segment { ackNumber: s.ackNumber, flags: s.flags, window: s.window, - route: s.route.Clone(), + netProto: s.netProto, + nicID: s.nicID, + remoteLinkAddr: s.remoteLinkAddr, viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, @@ -160,7 +176,6 @@ func (s *segment) decRef() { panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) } } - s.route.Release() } } @@ -198,10 +213,10 @@ func (s *segment) segMemSize() int { // // Returns boolean indicating if the parsing was successful. // -// If checksum verification is not offloaded then parse also verifies the +// If checksum verification may not be skipped, parse also verifies the // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. -func (s *segment) parse() bool { +func (s *segment) parse(skipChecksumValidation bool) bool { // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -220,16 +235,14 @@ func (s *segment) parse() bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - // Query the link capabilities to decide if checksum validation is - // required. verifyChecksum := true - if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + if skipChecksumValidation { s.csumValid = true verifyChecksum = false } if verifyChecksum { s.csum = s.hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6fa8d63cd..ab5fa4fb7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 { + return + } + // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. 
sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a7149efd0..5f05608e2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) { } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendPacket(nil, &context.Headers{ @@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki } func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendV6Packet(nil, &context.Headers{ @@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) { // Test acceptance. // Start listening. - listenBacklog := 2 + listenBacklog := 10 if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } time.Sleep(50 * time.Millisecond) @@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, + SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(789), @@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { @@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } + // Send a SYN request w/ sequence number higher than // the highest sequence number sent. iss = seqnum.Value(792) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4d7847142..f791f8f13 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -112,6 +112,18 @@ type Headers struct { TCPOpts []byte } +// Options contains options for creating a new test context. +type Options struct { + // EnableV4 indicates whether IPv4 should be enabled. + EnableV4 bool + + // EnableV6 indicates whether IPv4 should be enabled. + EnableV6 bool + + // MTU indicates the maximum transmission unit on the link layer. + MTU uint32 +} + // Context provides an initialized Network stack and a link layer endpoint // for use in TCP tests. type Context struct { @@ -154,10 +166,30 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. 
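// The Options-based constructor added below lets a test opt out of a network
// protocol, as the new dual-stack test earlier in this diff does:
//
//	c := context.NewWithOpts(t, context.Options{
//		EnableV6: true,
//		MTU:      defaultMTU,
//	})
//	defer c.Cleanup()
//
// New(t, mtu) keeps its old behaviour by enabling both IPv4 and IPv6.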
func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + return NewWithOpts(t, Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, }) +} + +// NewWithOpts allocates and initializes a test context containing a new +// stack and a link-layer endpoint with specific options. +func NewWithOpts(t *testing.T, opts Options) *Context { + if opts.MTU == 0 { + panic("MTU must be greater than 0") + } + + stackOpts := stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + if opts.EnableV4 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) + } + if opts.EnableV6 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) + } + s := stack.New(stackOpts) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB @@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - ep := channel.New(1000, mtu, "") + ep := channel.New(1000, opts.MTU, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) + wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) } opts2 := stack.NICOptions{Name: "nic2"} if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } + var routeTable []tcpip.Route - s.SetRouteTable([]tcpip.Route{ - { + if opts.EnableV4 { + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, - }, - { + }) + } + + if opts.EnableV6 { + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, NIC: 1, - }, - }) + }) + } + + s.SetRouteTable(routeTable) return &Context{ t: t, @@ -373,6 +410,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code 
header.ICMPv4Code, const icmpv4VariableHeaderOffset = 4 copy(icmp[icmpv4VariableHeaderOffset:], p1) copy(icmp[header.ICMPv4PayloadOffset:], p2) + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) // Inject packet. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d57ed5d79..9bcb918bb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } + if to.Port == 0 { + // Port 0 is an invalid port to send to. + return 0, nil, tcpip.ErrInvalidEndpointState + } + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -895,6 +900,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil + case tcpip.AcceptConnOption: + return false, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -1009,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + if r.RequiresTXTransportChecksum() && (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { @@ -1366,6 +1374,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.rcvMu.Unlock() } + e.lastErrorMu.Lock() + hasError := e.lastError != nil + e.lastErrorMu.Unlock() + if hasError { + result |= waiter.EventErr + } return result } @@ -1373,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // On IPv4, UDP checksum is optional, and a zero value means the transmitter // omitted the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). -func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) +func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { + if !pkt.RXTransportChecksumValidated && + (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { + netHdr := pkt.Network() + xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) } @@ -1387,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. 
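// Related to the Readiness change above and the HandleControlPacket change
// below: a pending asynchronous error (tcpip.ErrConnectionRefused set on an
// ICMP port-unreachable) now raises waiter.EventErr. A caller that wants to be
// woken on it can register for that event as well (sketch, using the waiter
// API as in the tests in this diff):
//
//	we, ch := waiter.NewChannelEntry(nil)
//	wq.EventRegister(&we, waiter.EventIn|waiter.EventErr)
//	defer wq.EventUnregister(&we)
//	<-ch // readable data or a pending error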
hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { @@ -1397,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - if !verifyChecksum(r, hdr, pkt) { + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() @@ -1428,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, @@ -1438,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvBufSize += pkt.Data.Size() // Save any useful information from the network header to the packet. - switch r.NetProto { + switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: @@ -1448,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast // address. packetInfo.LocalAddr should hold a unicast address that can be // used to respond to the incoming packet. - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.LocalAddress - packet.packetInfo.NIC = r.NICID() + localAddr := pkt.Network().DestinationAddress() + packet.packetInfo.LocalAddr = localAddr + packet.packetInfo.DestinationAddr = localAddr + packet.packetInfo.NIC = pkt.NICID packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() @@ -1465,14 +1481,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() - defer e.mu.RUnlock() - if e.state == StateConnected { e.lastErrorMu.Lock() - defer e.lastErrorMu.Unlock() - e.lastError = tcpip.ErrConnectionRefused + e.lastErrorMu.Unlock() + e.mu.RUnlock() + + e.waiterQueue.Notify(waiter.EventErr) + return } + e.mu.RUnlock() } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 3ae6cc221..14e4648cd 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, - route: r, id: id, pkt: pkt, }) @@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p // it via CreateEndpoint. type ForwarderRequest struct { stack *stack.Stack - route *stack.Route id stack.TransportEndpointID pkt *stack.PacketBuffer } @@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. 
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + netHdr := r.pkt.Network() + route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(r.pkt.SourceLinkAddress()) + + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() + route.Release() return nil, err } ep.ID = r.id - ep.route = r.route.Clone() + ep.route = route ep.dstPort = r.id.RemotePort - ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} - ep.RegisterNICID = r.route.NICID() + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} + ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.pkt) + ep.HandlePacket(r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index da5b1deb2..91420edd3 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets that are targeted at this // protocol but don't match any existing endpoint. 
-func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } - if !verifyChecksum(r, hdr, pkt) { - r.Stack().Stats().UDP.ChecksumErrors.Increment() + if !verifyChecksum(hdr, pkt) { + p.stack.Stats().UDP.ChecksumErrors.Increment() return stack.UnknownDestinationPacketMalformed } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b4604ba35..fb7738dda 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) { // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } // In the case of large payloads the IP packet may be truncated. 
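For context on the UDP hunks above: verifyChecksum no longer takes a *stack.Route and instead derives the pseudo-header from the packet's own network header (source/destination addresses), and the comments in the diff restate the rule that an IPv4 UDP checksum of zero means the sender skipped it (RFC 768), while IPv6 always requires one. As a point of reference only, here is a minimal standalone sketch of that arithmetic in plain Go; it does not use gVisor's header package, and the names onesComplementSum and udpChecksumIPv4 are illustrative, not real gVisor APIs.

package main

import (
	"encoding/binary"
	"fmt"
	"net"
)

// onesComplementSum adds b to a running one's-complement sum of 16-bit words
// and folds any carries back into the low 16 bits (RFC 1071).
func onesComplementSum(sum uint32, b []byte) uint32 {
	for len(b) >= 2 {
		sum += uint32(binary.BigEndian.Uint16(b))
		b = b[2:]
	}
	if len(b) == 1 {
		sum += uint32(b[0]) << 8 // pad a trailing odd byte with zero
	}
	for sum > 0xffff {
		sum = (sum & 0xffff) + (sum >> 16)
	}
	return sum
}

// udpChecksumIPv4 computes the checksum of a UDP segment (header + payload,
// checksum field set to zero) carried between two IPv4 addresses.
func udpChecksumIPv4(src, dst net.IP, segment []byte) uint16 {
	// IPv4 pseudo-header: src(4) | dst(4) | zero(1) | protocol(1) | UDP length(2).
	pseudo := make([]byte, 12)
	copy(pseudo[0:4], src.To4())
	copy(pseudo[4:8], dst.To4())
	pseudo[9] = 17 // IP protocol number for UDP
	binary.BigEndian.PutUint16(pseudo[10:12], uint16(len(segment)))

	sum := onesComplementSum(0, pseudo)
	sum = onesComplementSum(sum, segment)
	return ^uint16(sum)
}

func main() {
	seg := []byte{
		0x30, 0x39, 0x00, 0x35, // src port 12345, dst port 53
		0x00, 0x09, 0x00, 0x00, // length 9, checksum zeroed before computing
		'x', // one byte of payload
	}
	fmt.Printf("checksum: %#04x\n",
		udpChecksumIPv4(net.IPv4(192, 0, 2, 1), net.IPv4(192, 0, 2, 2), seg))
}

The SendICMPPacket hunk earlier in this section fills in its checksum the same way: zero the field, compute the one's-complement sum over the message, then store the complement.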
Update diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index 5c4b9e8e9..a38ffc19d 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -53,40 +53,40 @@ func randomFilename() (string, error) { func TestConnectFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") + t.Fatalf("Connect was successful, expected err") } } func TestBindFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") + t.Fatalf("Second bind succeeded, expected non-nil err") } } func TestMultipleAccept(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() @@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) { defer wg.Done() s, err := Connect(name, false) if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) + t.Errorf("Connect failed, got err %v expected nil", err) + return } s.Close() }() @@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) { for i := 0; i < backlog; i++ { s, err := ss.Accept() if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) + t.Errorf("Accept failed, got err %v expected nil", err) continue } s.Close() @@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) { func TestServerClose(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } // Make sure the first close succeeds. if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } // Bind a server. ss, err := BindAndListen(name, packet) if err != nil { - t.Fatalf("error binding, got %v expected nil", err) + t.Fatalf("Error binding, got %v expected nil", err) } defer ss.Close() @@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { // Connect the client. 
client, err := Connect(name, packet) if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) + t.Fatalf("Error connecting, got %v expected nil", err) } // Grab the server handle. @@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { case server := <-acceptSocket: return server, client case err := <-acceptErr: - t.Fatalf("accept error: %v", err) + t.Fatalf("Accept error: %v", err) } panic("unreachable") } @@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. b := [][]byte{{'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) { // Write on the server. w := server.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -233,13 +234,13 @@ func TestPacket(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Write on the client again. w = client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. @@ -249,19 +250,19 @@ func TestPacket(t *testing.T) { b := [][]byte{{'b', 'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Do it again. r = server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -271,12 +272,12 @@ func TestClose(t *testing.T) { // Make sure the first close succeeds. 
if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } @@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) { // We're good. That's what we wanted. blockCount++ } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err) } } } if blockCount == 1000 { // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") + t.Fatalf("Socket always blocked!") } else if blockCount == 0 { // Should have started blocking eventually. - t.Fatalf("socket never blocked!") + t.Fatalf("Socket never blocked!") } } @@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) { // Expected to block immediately. _, err := r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } // Put some data in the pipe. w := server.Writer(false) if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it not to block. if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it to return a block error again. r = client.Reader(false) _, err = r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } } @@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c'}, {'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) } } @@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. 
b := [][]byte{{'c', 'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) } } @@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) { w := server.Writer(true) w.PackFDs(0, 1, 2) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client, without enabling FDs. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Make sure the FDs are not received. fds, err := r.ExtractFDs() if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) } } @@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) { w := s.Writer(true) w.PackFDs(fds...) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err) } } @@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Count the number of FDs. preEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } // Read on the client. @@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { r.EnableFDs(enableSize) } if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Count the new number of FDs. postEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) } // Make sure the FDs are there. fds, err := r.ExtractFDs() if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) } // Make sure they are different from the originals. 
for i := 0; i < len(fds); i++ { if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) + t.Errorf("Got original fd for index %d, expected different", i) } } @@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Make sure the count is back to normal. finalEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) } } @@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) { } if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) + t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) } } @@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) { want := "bad file descriptor" if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) + t.Errorf("s.GetPeerCred() = %v, want = %s", err, want) } } func TestAcceptClosed(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Fatalf("Close failed, got err %v expected nil", err) } if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } } func TestCloseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Errorf("Close failed, got err %v expected nil", err) } - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) { func TestReleaseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) fd, err := ss.Release() if err != nil { - 
t.Fatalf("Release failed, got err %v expected nil", err) + t.Errorf("Release failed, got err %v expected nil", err) } syscall.Close(fd) - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) { cm.PackFDs(want...) got, err := cm.ExtractFDs() if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) + t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) } } } @@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) { for i := 0; i < b.N; i++ { n, err := server.Read(buf) if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + return } n, err = server.Write(buf) if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + return } } }() diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 67a950444..08519d986 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { // // +stateify savable type Queue struct { - list waiterList `state:"zerovalue"` + list waiterList mu sync.RWMutex `state:"nosave"` } diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 248f77c34..8c73dc5dc 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -38,6 +38,7 @@ go_library( "//pkg/memutil", "//pkg/rand", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/arch:registers_go_proto", "//pkg/sentry/control", @@ -74,6 +75,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/sighandling", "//pkg/sentry/socket/hostinet", + "//pkg/sentry/socket/netfilter", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netlink/route", "//pkg/sentry/socket/netlink/uevent", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 894651519..fdf13c8e1 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -30,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/state" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" @@ -195,7 +196,7 @@ type containerManager struct { // StartRoot will start the root container process. func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error { - log.Debugf("containerManager.StartRoot %q", *cid) + log.Debugf("containerManager.StartRoot, cid: %s", *cid) // Tell the root container to start and wait for the result. cm.startChan <- struct{}{} if err := <-cm.startResultChan; err != nil { @@ -206,13 +207,13 @@ func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error { // Processes retrieves information about processes running in the sandbox. func (cm *containerManager) Processes(cid *string, out *[]*control.Process) error { - log.Debugf("containerManager.Processes: %q", *cid) + log.Debugf("containerManager.Processes, cid: %s", *cid) return control.Processes(cm.l.k, *cid, out) } // Create creates a container within a sandbox. 
func (cm *containerManager) Create(cid *string, _ *struct{}) error { - log.Debugf("containerManager.Create: %q", *cid) + log.Debugf("containerManager.Create, cid: %s", *cid) return cm.l.createContainer(*cid) } @@ -236,12 +237,11 @@ type StartArgs struct { // Start runs a created container within a sandbox. func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { - log.Debugf("containerManager.Start: %+v", args) - // Validate arguments. if args == nil { return errors.New("start missing arguments") } + log.Debugf("containerManager.Start, cid: %s, args: %+v", args.CID, args) if args.Spec == nil { return errors.New("start arguments missing spec") } @@ -268,27 +268,27 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { } }() if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil { - log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err) + log.Debugf("containerManager.Start failed, cid: %s, args: %+v, err: %v", args.CID, args, err) return err } - log.Debugf("Container %q started", args.CID) + log.Debugf("Container started, cid: %s", args.CID) return nil } // Destroy stops a container if it is still running and cleans up its // filesystem. func (cm *containerManager) Destroy(cid *string, _ *struct{}) error { - log.Debugf("containerManager.destroy %q", *cid) + log.Debugf("containerManager.destroy, cid: %s", *cid) return cm.l.destroyContainer(*cid) } // ExecuteAsync starts running a command on a created or running sandbox. It // returns the PID of the new process. func (cm *containerManager) ExecuteAsync(args *control.ExecArgs, pid *int32) error { - log.Debugf("containerManager.ExecuteAsync: %+v", args) + log.Debugf("containerManager.ExecuteAsync, cid: %s, args: %+v", args.ContainerID, args) tgid, err := cm.l.executeAsync(args) if err != nil { - log.Debugf("containerManager.ExecuteAsync failed: %+v: %v", args, err) + log.Debugf("containerManager.ExecuteAsync failed, cid: %s, args: %+v, err: %v", args.ContainerID, args, err) return err } *pid = int32(tgid) @@ -367,12 +367,20 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { cm.l.k = k // Set up the restore environment. + ctx := k.SupervisorContext() mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints) - renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) - if err != nil { - return fmt.Errorf("creating RestoreEnvironment: %v", err) + if kernel.VFS2Enabled { + ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) + if err != nil { + return fmt.Errorf("configuring filesystem restore: %v", err) + } + } else { + renv, err := mntr.createRestoreEnvironment(cm.l.root.conf) + if err != nil { + return fmt.Errorf("creating RestoreEnvironment: %v", err) + } + fs.SetRestoreEnvironment(*renv) } - fs.SetRestoreEnvironment(*renv) // Prepare to load from the state file. if eps, ok := networkStack.(*netstack.Stack); ok { @@ -399,7 +407,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Load the state. loadOpts := state.LoadOpts{Source: specFile} - if err := loadOpts.Load(k, networkStack, time.NewCalibratedClocks()); err != nil { + if err := loadOpts.Load(ctx, k, networkStack, time.NewCalibratedClocks(), &vfs.CompleteRestoreOptions{}); err != nil { return err } @@ -444,9 +452,9 @@ func (cm *containerManager) Resume(_, _ *struct{}) error { // Wait waits for the init process in the given container. 
func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error { - log.Debugf("containerManager.Wait") + log.Debugf("containerManager.Wait, cid: %s", *cid) err := cm.l.waitContainer(*cid, waitStatus) - log.Debugf("containerManager.Wait returned, waitStatus: %v: %v", waitStatus, err) + log.Debugf("containerManager.Wait returned, cid: %s, waitStatus: %#x, err: %v", *cid, *waitStatus, err) return err } @@ -461,8 +469,10 @@ type WaitPIDArgs struct { // WaitPID waits for the process with PID 'pid' in the sandbox. func (cm *containerManager) WaitPID(args *WaitPIDArgs, waitStatus *uint32) error { - log.Debugf("containerManager.Wait") - return cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus) + log.Debugf("containerManager.Wait, cid: %s, pid: %d", args.CID, args.PID) + err := cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus) + log.Debugf("containerManager.Wait, cid: %s, pid: %d, waitStatus: %#x, err: %v", args.CID, args.PID, *waitStatus, err) + return err } // SignalDeliveryMode enumerates different signal delivery modes. @@ -519,6 +529,6 @@ type SignalArgs struct { // indicated process, to all processes in the container, or to the foreground // process group. func (cm *containerManager) Signal(args *SignalArgs, _ *struct{}) error { - log.Debugf("containerManager.Signal %+v", args) + log.Debugf("containerManager.Signal: cid: %s, PID: %d, signal: %d, mode: %v", args.CID, args.PID, args.Signo, args.Mode) return cm.l.signal(args.CID, args.PID, args.Signo, args.Mode) } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index ddf288456..6b6ae98d7 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -105,33 +105,28 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name // mandatory mounts that are required by the OCI specification. func compileMounts(spec *specs.Spec) []specs.Mount { // Keep track of whether proc and sys were mounted. - var procMounted, sysMounted bool + var procMounted, sysMounted, devMounted, devptsMounted bool var mounts []specs.Mount - // Always mount /dev. - mounts = append(mounts, specs.Mount{ - Type: devtmpfs.Name, - Destination: "/dev", - }) - - mounts = append(mounts, specs.Mount{ - Type: devpts.Name, - Destination: "/dev/pts", - }) - // Mount all submounts from the spec. for _, m := range spec.Mounts { if !specutils.IsSupportedDevMount(m) { log.Warningf("ignoring dev mount at %q", m.Destination) continue } - mounts = append(mounts, m) switch filepath.Clean(m.Destination) { case "/proc": procMounted = true case "/sys": sysMounted = true + case "/dev": + m.Type = devtmpfs.Name + devMounted = true + case "/dev/pts": + m.Type = devpts.Name + devptsMounted = true } + mounts = append(mounts, m) } // Mount proc and sys even if the user did not ask for it, as the spec @@ -149,6 +144,18 @@ func compileMounts(spec *specs.Spec) []specs.Mount { Destination: "/sys", }) } + if !devMounted { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: devtmpfs.Name, + Destination: "/dev", + }) + } + if !devptsMounted { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: devpts.Name, + Destination: "/dev/pts", + }) + } // The mandatory mounts should be ordered right after the root, in case // there are submounts of these mandatory mounts already in the spec. 
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 8ad000497..ebdd518d0 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/memutil" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fdimport" @@ -49,6 +50,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/sighandling" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -476,6 +478,10 @@ func (l *Loader) Destroy() { // save/restore. l.k.Release() + // All sentry-created resources should have been released at this point; + // check for reference leaks. + refsvfs2.DoLeakCheck() + // In the success case, stdioFDs and goferFDs will only contain // released/closed FDs that ownership has been passed over to host FDs and // gofer sessions. Close them here in case of failure. @@ -737,7 +743,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn return nil, err } - // Add the HOME enviroment variable if it is not already set. + // Add the HOME environment variable if it is not already set. var envv []string if kernel.VFS2Enabled { envv, err = user.MaybeAddExecUserHomeVFS2(ctx, info.procArgs.MountNamespaceVFS2, @@ -882,7 +888,7 @@ func (l *Loader) destroyContainer(cid string) error { } } - log.Debugf("Container destroyed %q", cid) + log.Debugf("Container destroyed, cid: %s", cid) return nil } @@ -1079,6 +1085,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in // privileges. RawFactory: raw.EndpointFactory{}, UniqueID: uniqueID, + IPTables: netfilter.DefaultLinuxTables(), })} // Enable SACK Recovery. diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index e376f944b..b77b4762e 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -266,7 +266,7 @@ type CreateMountTestcase struct { func createMountTestcases() []*CreateMountTestcase { testCases := []*CreateMountTestcase{ - &CreateMountTestcase{ + { // Only proc. name: "only proc mount", spec: specs.Spec{ @@ -304,11 +304,10 @@ func createMountTestcases() []*CreateMountTestcase { }, }, }, - // /some/deep/path should be mounted, along with /proc, - // /dev, and /sys. + // /some/deep/path should be mounted, along with /proc, /dev, and /sys. expectedPaths: []string{"/some/very/very/deep/path", "/proc", "/dev", "/sys"}, }, - &CreateMountTestcase{ + { // Mounts are nested inside each other. 
name: "nested mounts", spec: specs.Spec{ @@ -352,7 +351,7 @@ func createMountTestcases() []*CreateMountTestcase { expectedPaths: []string{"/foo", "/foo/bar", "/foo/bar/baz", "/foo/qux", "/foo/qux-quz", "/foo/some/very/very/deep/path", "/proc", "/dev", "/sys"}, }, - &CreateMountTestcase{ + { name: "mount inside /dev", spec: specs.Spec{ Root: &specs.Root{ @@ -395,35 +394,37 @@ func createMountTestcases() []*CreateMountTestcase { }, expectedPaths: []string{"/proc", "/dev", "/dev/fd-foo", "/dev/foo", "/dev/bar", "/sys"}, }, - } - - vfsCase := &CreateMountTestcase{ - name: "mounts inside mandatory mounts", - spec: specs.Spec{ - Root: &specs.Root{ - Path: os.TempDir(), - Readonly: true, - }, - Mounts: []specs.Mount{ - { - Destination: "/proc", - Type: "tmpfs", - }, - { - Destination: "/sys/bar", - Type: "tmpfs", + { + name: "mounts inside mandatory mounts", + spec: specs.Spec{ + Root: &specs.Root{ + Path: os.TempDir(), + Readonly: true, }, - - { - Destination: "/tmp/baz", - Type: "tmpfs", + Mounts: []specs.Mount{ + { + Destination: "/proc", + Type: "tmpfs", + }, + { + Destination: "/sys/bar", + Type: "tmpfs", + }, + { + Destination: "/tmp/baz", + Type: "tmpfs", + }, + { + Destination: "/dev/goo", + Type: "tmpfs", + }, }, }, + expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz", "/dev/goo"}, }, - expectedPaths: []string{"/proc", "/sys", "/sys/bar", "/tmp", "/tmp/baz"}, } - return append(testCases, vfsCase) + return testCases } // Test that MountNamespace can be created with various specs. diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 004da5b40..b157387ef 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -210,6 +210,9 @@ func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *c ReadOnly: c.root.Readonly, GetFilesystemOptions: vfs.GetFilesystemOptions{ Data: strings.Join(data, ","), + InternalData: gofer.InternalFilesystemOptions{ + UniqueID: "/", + }, }, InternalMount: true, } @@ -427,6 +430,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo fsName := m.Type useOverlay := false var data []string + var iopts interface{} // Find filesystem name and FS specific data field. switch m.Type { @@ -451,6 +455,9 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) + iopts = gofer.InternalFilesystemOptions{ + UniqueID: m.Destination, + } // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly @@ -462,7 +469,8 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo opts := &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ - Data: strings.Join(data, ","), + Data: strings.Join(data, ","), + InternalData: iopts, }, InternalMount: true, } @@ -667,3 +675,21 @@ func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Crede } return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds) } + +// configureRestore returns an updated context.Context including filesystem +// state used by restore defined by conf. 
+func (c *containerMounter) configureRestore(ctx context.Context, conf *config.Config) (context.Context, error) { + fdmap := make(map[string]int) + fdmap["/"] = c.fds.remove() + mounts, err := c.prepareMountsVFS2() + if err != nil { + return ctx, err + } + for i := range c.mounts { + submount := &mounts[i] + if submount.fd >= 0 { + fdmap[submount.Destination] = submount.fd + } + } + return context.WithValue(ctx, gofer.CtxRestoreServerFDMap, fdmap), nil +} diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 56da21584..5bd0afc52 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -21,6 +21,7 @@ import ( "context" "errors" "fmt" + "io" "io/ioutil" "os" "path/filepath" @@ -198,8 +199,13 @@ func LoadPaths(pid string) (map[string]string, error) { } defer f.Close() + return loadPathsHelper(f) +} + +func loadPathsHelper(cgroup io.Reader) (map[string]string, error) { paths := make(map[string]string) - scanner := bufio.NewScanner(f) + + scanner := bufio.NewScanner(cgroup) for scanner.Scan() { // Format: ID:[name=]controller1,controller2:path // Example: 2:cpu,cpuacct:/user.slice @@ -207,6 +213,9 @@ func LoadPaths(pid string) (map[string]string, error) { if len(tokens) != 3 { return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text()) } + if len(tokens[1]) == 0 { + continue + } for _, ctrlr := range strings.Split(tokens[1], ",") { // Remove prefix for cgroups with no controller, eg. systemd. ctrlr = strings.TrimPrefix(ctrlr, "name=") diff --git a/runsc/cgroup/cgroup_test.go b/runsc/cgroup/cgroup_test.go index 4db5ee5c3..9794517a7 100644 --- a/runsc/cgroup/cgroup_test.go +++ b/runsc/cgroup/cgroup_test.go @@ -647,3 +647,83 @@ func TestPids(t *testing.T) { }) } } + +func TestLoadPaths(t *testing.T) { + for _, tc := range []struct { + name string + cgroups string + want map[string]string + err string + }{ + { + name: "abs-path", + cgroups: "0:ctr:/path", + want: map[string]string{"ctr": "/path"}, + }, + { + name: "rel-path", + cgroups: "0:ctr:rel-path", + want: map[string]string{"ctr": "rel-path"}, + }, + { + name: "non-controller", + cgroups: "0:name=systemd:/path", + want: map[string]string{"systemd": "/path"}, + }, + { + name: "empty", + }, + { + name: "multiple", + cgroups: "0:ctr0:/path0\n" + + "1:ctr1:/path1\n" + + "2::/empty\n", + want: map[string]string{ + "ctr0": "/path0", + "ctr1": "/path1", + }, + }, + { + name: "missing-field", + cgroups: "0:nopath\n", + err: "invalid cgroups file", + }, + { + name: "too-many-fields", + cgroups: "0:ctr:/path:extra\n", + err: "invalid cgroups file", + }, + { + name: "multiple-malformed", + cgroups: "0:ctr0:/path0\n" + + "1:ctr1:/path1\n" + + "2:\n", + err: "invalid cgroups file", + }, + } { + t.Run(tc.name, func(t *testing.T) { + r := strings.NewReader(tc.cgroups) + got, err := loadPathsHelper(r) + if len(tc.err) == 0 { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } else if !strings.Contains(err.Error(), tc.err) { + t.Fatalf("Wrong error message, want: *%s*, got: %v", tc.err, err) + } + for key, vWant := range tc.want { + vGot, ok := got[key] + if !ok { + t.Errorf("Missing controller %q", key) + } + if vWant != vGot { + t.Errorf("Wrong controller %q value, want: %q, got: %q", key, vWant, vGot) + } + delete(got, key) + } + for k, v := range got { + t.Errorf("Unexpected controller %q: %q", k, v) + } + }) + } +} diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go index cd419e1aa..2c92e3067 100644 --- a/runsc/cmd/boot.go +++ b/runsc/cmd/boot.go @@ -131,11 +131,11 @@ func (b *Boot) Execute(_ 
context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } - // Ensure that if there is a panic, all goroutine stacks are printed. - debug.SetTraceback("system") - conf := args[0].(*config.Config) + // Set traceback level + debug.SetTraceback(conf.Traceback) + if b.attached { // Ensure this process is killed after parent process terminates when // attached mode is enabled. In the unfortunate event that the parent diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index 8fe0c427a..c0bc8f064 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -75,7 +75,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } @@ -149,6 +149,9 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa } ws, err := cont.Wait() + if err != nil { + Fatalf("Error waiting for container: %v", err) + } *waitStatus = ws return subcommands.ExitSuccess diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 132198222..609e8231c 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -91,7 +91,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } var err error - c, err = container.Load(conf.RootDir, f.Arg(0)) + c, err = container.LoadAndCheck(conf.RootDir, f.Arg(0)) if err != nil { return Errorf("loading container %q: %v", f.Arg(0), err) } @@ -106,7 +106,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return Errorf("listing containers: %v", err) } for _, id := range ids { - candidate, err := container.Load(conf.RootDir, id) + candidate, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { return Errorf("loading container %q: %v", id, err) } diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go index 4e49deff8..a25637265 100644 --- a/runsc/cmd/delete.go +++ b/runsc/cmd/delete.go @@ -68,7 +68,7 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} func (d *Delete) execute(ids []string, conf *config.Config) error { for _, id := range ids { - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { if os.IsNotExist(err) && d.force { log.Warningf("couldn't find container %q: %v", id, err) diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 25fe2cf1c..3836b7b4e 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -74,7 +74,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } @@ -85,7 +85,12 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa ev, err := c.Event() if err != nil { log.Warningf("Error getting events for container: %v", err) + if evs.stats { + return subcommands.ExitFailure + } } + log.Debugf("Events: %+v", ev) + // err must be preserved because it is used below when breaking // out of the loop. 
b, err := json.Marshal(ev) @@ -101,11 +106,9 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa if err != nil { return subcommands.ExitFailure } - break + return subcommands.ExitSuccess } time.Sleep(time.Duration(evs.intervalSec) * time.Second) } - - return subcommands.ExitSuccess } diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index 775ed4b43..86c02a22a 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -112,7 +112,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } waitStatus := args[1].(*syscall.WaitStatus) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index 04eee99b2..fe69e2a08 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -69,7 +69,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("it is invalid to specify both --all and --pid") } - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go index f92d6fef9..6907eb16a 100644 --- a/runsc/cmd/list.go +++ b/runsc/cmd/list.go @@ -79,7 +79,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Collect the containers. var containers []*container.Container for _, id := range ids { - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container %q: %v", id, err) } diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go index 0eb1402ed..fe7d4e257 100644 --- a/runsc/cmd/pause.go +++ b/runsc/cmd/pause.go @@ -55,7 +55,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go index bc58c928f..18d7a1436 100644 --- a/runsc/cmd/ps.go +++ b/runsc/cmd/ps.go @@ -60,7 +60,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go index f24823f99..a00928204 100644 --- a/runsc/cmd/resume.go +++ b/runsc/cmd/resume.go @@ -56,7 +56,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 88991b521..f6499cc44 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/specutils" ) // Start implements subcommands.Command for the "start" command. 
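The runsc/cmd/boot.go hunk earlier in this section replaces the hard-coded debug.SetTraceback("system") with the value of the new --traceback flag (default "system") that the config changes below register. A tiny standalone sketch of that wiring, using only the standard library; the flag name and default mirror the runsc change, everything else here is illustrative.

package main

import (
	"flag"
	"runtime/debug"
)

func main() {
	// Valid levels, per the Go runtime: "none", "single", "all", "system", "crash".
	traceback := flag.String("traceback", "system", "golang runtime's traceback level")
	flag.Parse()

	// Controls how much goroutine detail the runtime prints for fatal errors.
	debug.SetTraceback(*traceback)

	// Deliberate panic so the configured traceback level can be observed.
	panic("demonstrate traceback output")
}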
@@ -54,10 +55,16 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } + // Read the spec again here to ensure flag annotations from the spec are + // applied to "conf". + if _, err := specutils.ReadSpec(c.BundleDir, conf); err != nil { + Fatalf("reading spec: %v", err) + } + if err := c.Start(conf); err != nil { Fatalf("starting container: %v", err) } diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go index 2bd2ab9f8..d8a70dd7f 100644 --- a/runsc/cmd/state.go +++ b/runsc/cmd/state.go @@ -57,7 +57,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index 28d0642ed..c1d6aeae2 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -72,7 +72,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/config/config.go b/runsc/config/config.go index f30f79f68..b02d8e2e1 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -37,6 +37,9 @@ type Config struct { // RootDir is the runtime root directory. RootDir string `flag:"root"` + // Traceback changes the Go runtime's traceback level. + Traceback string `flag:"traceback"` + // Debug indicates that debug logging should be enabled. Debug bool `flag:"debug"` diff --git a/runsc/config/flags.go b/runsc/config/flags.go index a5f25cfa2..13d8f1b25 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -29,7 +29,7 @@ import ( var registration sync.Once -// This is the set of flags used to populate Config. +// RegisterFlags registers flags used to populate Config. func RegisterFlags() { registration.Do(func() { // Although these flags are not part of the OCI spec, they are used by @@ -49,6 +49,7 @@ func RegisterFlags() { flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") flag.Bool("alsologtostderr", false, "send log messages to stderr.") flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.") + flag.String("traceback", "system", "golang runtime's traceback level") // Debugging flags: strace related flag.Bool("strace", false, "enable strace.") diff --git a/runsc/container/container.go b/runsc/container/container.go index 63f64ce6e..4aa139c88 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -159,9 +159,9 @@ func loadSandbox(rootDir, id string) ([]*Container, error) { // container to which id unambiguously refers to. Returns ErrNotExist if // container doesn't exist. 
func Load(rootDir, partialID string) (*Container, error) { - log.Debugf("Load container %q %q", rootDir, partialID) + log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID) if err := validateID(partialID); err != nil { - return nil, fmt.Errorf("validating id: %v", err) + return nil, fmt.Errorf("invalid container id: %v", err) } id, err := findContainerID(rootDir, partialID) @@ -184,22 +184,31 @@ func Load(rootDir, partialID string) (*Container, error) { } return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) } + return c, nil +} + +// LoadAndCheck is similar to Load(), but also checks if the container is still +// running to get an error earlier to the caller. +func LoadAndCheck(rootDir, partialID string) (*Container, error) { + c, err := Load(rootDir, partialID) + if err != nil { + // Preserve error so that callers can distinguish 'not found' errors. + return nil, err + } - // If the status is "Running" or "Created", check that the sandbox - // process still exists, and set it to Stopped if it does not. + // If the status is "Running" or "Created", check that the sandbox/container + // is still running, setting it to Stopped if not. // // This is inherently racy. - if c.Status == Running || c.Status == Created { - // Check if the sandbox process is still running. + switch c.Status { + case Created: if !c.isSandboxRunning() { // Sandbox no longer exists, so this container definitely does not exist. c.changeStatus(Stopped) - } else if c.Status == Running { - // Container state should reflect the actual state of the application, so - // we don't consider gofer process here. - if err := c.SignalContainer(syscall.Signal(0), false); err != nil { - c.changeStatus(Stopped) - } + } + case Running: + if err := c.SignalContainer(syscall.Signal(0), false); err != nil { + c.changeStatus(Stopped) } } @@ -271,7 +280,7 @@ type Args struct { // indicates that an existing Sandbox should be used. The caller must call // Destroy() on the container. func New(conf *config.Config, args Args) (*Container, error) { - log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir) + log.Debugf("Create container, cid: %s, rootDir: %q", args.ID, conf.RootDir) if err := validateID(args.ID); err != nil { return nil, err } @@ -310,7 +319,15 @@ func New(conf *config.Config, args Args) (*Container, error) { // indicate the ID of the sandbox, which is the same as the ID of the // init container in the sandbox. if isRoot(args.Spec) { - log.Debugf("Creating new sandbox for container %q", args.ID) + log.Debugf("Creating new sandbox for container, cid: %s", args.ID) + + if args.Spec.Linux == nil { + args.Spec.Linux = &specs.Linux{} + } + // Don't force the use of cgroups in tests because they lack permission to do so. + if args.Spec.Linux.CgroupsPath == "" && !conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { + args.Spec.Linux.CgroupsPath = "/" + args.ID + } // Create and join cgroup before processes are created to ensure they are // part of the cgroup from the start (and all their children processes). @@ -321,7 +338,13 @@ func New(conf *config.Config, args Args) (*Container, error) { if cg != nil { // If there is cgroup config, install it before creating sandbox process. 
if err := cg.Install(args.Spec.Linux.Resources); err != nil { - return nil, fmt.Errorf("configuring cgroup: %v", err) + switch { + case errors.Is(err, syscall.EACCES) && conf.Rootless: + log.Warningf("Skipping cgroup configuration in rootless mode: %v", err) + cg = nil + default: + return nil, fmt.Errorf("configuring cgroup: %v", err) + } } } if err := runInCgroup(cg, func() error { @@ -366,10 +389,10 @@ func New(conf *config.Config, args Args) (*Container, error) { if !ok { return nil, fmt.Errorf("no sandbox ID found when creating container") } - log.Debugf("Creating new container %q in sandbox %q", c.ID, sbid) + log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sbid) // Find the sandbox associated with this ID. - sb, err := Load(conf.RootDir, sbid) + sb, err := LoadAndCheck(conf.RootDir, sbid) if err != nil { return nil, err } @@ -399,7 +422,7 @@ func New(conf *config.Config, args Args) (*Container, error) { // Start starts running the containerized process inside the sandbox. func (c *Container) Start(conf *config.Config) error { - log.Debugf("Start container %q", c.ID) + log.Debugf("Start container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err @@ -462,7 +485,7 @@ func (c *Container) Start(conf *config.Config) error { unlock.Clean() // Adjust the oom_score_adj for sandbox. This must be done after saveLocked(). - if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil { + if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Spec, c.Saver.RootDir, false); err != nil { return err } @@ -474,7 +497,7 @@ func (c *Container) Start(conf *config.Config) error { // Restore takes a container and replaces its kernel and file system // to restore a container from its state file. func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile string) error { - log.Debugf("Restore container %q", c.ID) + log.Debugf("Restore container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err } @@ -501,7 +524,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile s // Run is a helper that calls Create + Start + Wait. func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { - log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir) + log.Debugf("Run container, cid: %s, rootDir: %q", args.ID, conf.RootDir) c, err := New(conf, args) if err != nil { return 0, fmt.Errorf("creating container: %v", err) @@ -533,7 +556,7 @@ func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { // Execute runs the specified command in the container. It returns the PID of // the newly created process. func (c *Container) Execute(args *control.ExecArgs) (int32, error) { - log.Debugf("Execute in container %q, args: %+v", c.ID, args) + log.Debugf("Execute in container, cid: %s, args: %+v", c.ID, args) if err := c.requireStatus("execute in", Created, Running); err != nil { return 0, err } @@ -543,7 +566,7 @@ func (c *Container) Execute(args *control.ExecArgs) (int32, error) { // Event returns events for the container. func (c *Container) Event() (*boot.Event, error) { - log.Debugf("Getting events for container %q", c.ID) + log.Debugf("Getting events for container, cid: %s", c.ID) if err := c.requireStatus("get events for", Created, Running, Paused); err != nil { return nil, err } @@ -563,14 +586,19 @@ func (c *Container) SandboxPid() int { // Call to wait on a stopped container is needed to retrieve the exit status // and wait returns immediately. 
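The exit status mentioned here is a syscall.WaitStatus; for readers who have not decoded one before, a minimal, self-contained sketch (illustrative only, not part of the patch):

    package main

    import (
        "fmt"
        "syscall"
    )

    // describe decodes a syscall.WaitStatus the way callers of Container.Wait
    // typically would: distinguish a normal exit code from a fatal signal.
    func describe(ws syscall.WaitStatus) string {
        switch {
        case ws.Exited():
            return fmt.Sprintf("exited with status %d", ws.ExitStatus())
        case ws.Signaled():
            return fmt.Sprintf("killed by signal %v", ws.Signal())
        default:
            return "stopped or still running"
        }
    }

    func main() {
        var ws syscall.WaitStatus // the zero value decodes as "exited with status 0"
        fmt.Println(describe(ws))
    }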
func (c *Container) Wait() (syscall.WaitStatus, error) { - log.Debugf("Wait on container %q", c.ID) - return c.Sandbox.Wait(c.ID) + log.Debugf("Wait on container, cid: %s", c.ID) + ws, err := c.Sandbox.Wait(c.ID) + if err == nil { + // Wait succeeded, container is not running anymore. + c.changeStatus(Stopped) + } + return ws, err } // WaitRootPID waits for process 'pid' in the sandbox's PID namespace and // returns its WaitStatus. func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { - log.Debugf("Wait on PID %d in sandbox %q", pid, c.Sandbox.ID) + log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID) if !c.isSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } @@ -580,7 +608,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { // WaitPID waits for process 'pid' in the container's PID namespace and returns // its WaitStatus. func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { - log.Debugf("Wait on PID %d in container %q", pid, c.ID) + log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID) if !c.isSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } @@ -592,7 +620,7 @@ func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { // SignalContainer returns an error if the container is already stopped. // TODO(b/113680494): Distinguish different error types. func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { - log.Debugf("Signal container %q: %v", c.ID, sig) + log.Debugf("Signal container, cid: %s, signal: %v (%d)", c.ID, sig, sig) // Signaling container in Stopped state is allowed. When all=false, // an error will be returned anyway; when all=true, this allows // sending signal to other processes inside the container even @@ -609,7 +637,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { // SignalProcess sends sig to a specific process in the container. func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { - log.Debugf("Signal process %d in container %q: %v", pid, c.ID, sig) + log.Debugf("Signal process %d in container, cid: %s, signal: %v (%d)", pid, c.ID, sig, sig) if err := c.requireStatus("signal a process inside", Running); err != nil { return err } @@ -623,15 +651,15 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { // container process inside the sandbox. It returns a function that will stop // forwarding signals. 
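The start-forwarding-then-return-a-stop-function shape documented here is a common Go pattern; a generic standalone sketch using os/signal (an illustration of the pattern only, not gVisor's sighandling package):

    package main

    import (
        "fmt"
        "os"
        "os/signal"
        "time"
    )

    // forwardSignals starts relaying incoming signals to handle and returns a
    // function that stops the relay, mirroring the shape of ForwardSignals.
    func forwardSignals(handle func(os.Signal)) func() {
        ch := make(chan os.Signal, 1)
        signal.Notify(ch) // no arguments: relay all incoming signals
        done := make(chan struct{})
        go func() {
            for {
                select {
                case sig := <-ch:
                    handle(sig)
                case <-done:
                    return
                }
            }
        }()
        return func() {
            signal.Stop(ch)
            close(done)
        }
    }

    func main() {
        stop := forwardSignals(func(sig os.Signal) { fmt.Println("forwarding", sig) })
        time.Sleep(100 * time.Millisecond)
        stop() // the returned closure tears the forwarding down
    }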
 func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
-	log.Debugf("Forwarding all signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+	log.Debugf("Forwarding all signals to container, cid: %s, PID: %d, fgProcess: %t", c.ID, pid, fgProcess)
 	stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
-		log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", sig, c.ID, pid, fgProcess)
+		log.Debugf("Forwarding signal %d to container, cid: %s, PID: %d, fgProcess: %t", sig, c.ID, pid, fgProcess)
 		if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil {
 			log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err)
 		}
 	})
 	return func() {
-		log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess)
+		log.Debugf("Done forwarding signals to container, cid: %s, PID: %d, fgProcess: %t", c.ID, pid, fgProcess)
 		stop()
 	}
 }
@@ -639,7 +667,7 @@ func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
 // Checkpoint sends the checkpoint call to the container.
 // The statefile will be written to f, the file at the specified image-path.
 func (c *Container) Checkpoint(f *os.File) error {
-	log.Debugf("Checkpoint container %q", c.ID)
+	log.Debugf("Checkpoint container, cid: %s", c.ID)
 	if err := c.requireStatus("checkpoint", Created, Running, Paused); err != nil {
 		return err
 	}
@@ -649,7 +677,7 @@ func (c *Container) Checkpoint(f *os.File) error {
 // Pause suspends the container and its kernel.
 // The call only succeeds if the container's status is created or running.
 func (c *Container) Pause() error {
-	log.Debugf("Pausing container %q", c.ID)
+	log.Debugf("Pausing container, cid: %s", c.ID)
 	if err := c.Saver.lock(); err != nil {
 		return err
 	}
@@ -660,7 +688,7 @@ func (c *Container) Pause() error {
 	}
 
 	if err := c.Sandbox.Pause(c.ID); err != nil {
-		return fmt.Errorf("pausing container: %v", err)
+		return fmt.Errorf("pausing container %q: %v", c.ID, err)
 	}
 	c.changeStatus(Paused)
 	return c.saveLocked()
@@ -669,7 +697,7 @@ func (c *Container) Pause() error {
 // Resume unpauses the container and its kernel.
 // The call only succeeds if the container's status is paused.
 func (c *Container) Resume() error {
-	log.Debugf("Resuming container %q", c.ID)
+	log.Debugf("Resuming container, cid: %s", c.ID)
 	if err := c.Saver.lock(); err != nil {
 		return err
 	}
@@ -708,7 +736,7 @@ func (c *Container) Processes() ([]*control.Process, error) {
 // Destroy stops all processes and frees all resources associated with the
 // container.
 func (c *Container) Destroy() error {
-	log.Debugf("Destroy container %q", c.ID)
+	log.Debugf("Destroy container, cid: %s", c.ID)
 
 	if err := c.Saver.lock(); err != nil {
 		return err
@@ -745,14 +773,12 @@ func (c *Container) Destroy() error {
 	c.changeStatus(Stopped)
 
 	// Adjust oom_score_adj for the sandbox. This must be done after the container
-	// is stopped and the directory at c.Root is removed. Adjustment can be
-	// skipped if the root container is exiting, because it brings down the entire
-	// sandbox.
+	// is stopped and the directory at c.Root is removed.
 	//
 	// Use 'sb' to tell whether it has been executed before because Destroy must
 	// be idempotent.
- if sb != nil && !isRoot(c.Spec) { - if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil { + if sb != nil { + if err := adjustSandboxOOMScoreAdj(sb, c.Spec, c.Saver.RootDir, true); err != nil { errs = append(errs, err.Error()) } } @@ -781,7 +807,7 @@ func (c *Container) Destroy() error { // // Precondition: container must be locked with container.lock(). func (c *Container) saveLocked() error { - log.Debugf("Save container %q", c.ID) + log.Debugf("Save container, cid: %s", c.ID) if err := c.Saver.saveLocked(c); err != nil { return fmt.Errorf("saving container metadata: %v", err) } @@ -795,7 +821,7 @@ func (c *Container) stop() error { var cgroup *cgroup.Cgroup if c.Sandbox != nil { - log.Debugf("Destroying container %q", c.ID) + log.Debugf("Destroying container, cid: %s", c.ID) if err := c.Sandbox.DestroyContainer(c.ID); err != nil { return fmt.Errorf("destroying container %q: %v", c.ID, err) } @@ -809,7 +835,7 @@ func (c *Container) stop() error { // Try killing gofer if it does not exit with container. if c.GoferPid != 0 { - log.Debugf("Killing gofer for container %q, PID: %d", c.ID, c.GoferPid) + log.Debugf("Killing gofer for container, cid: %s, PID: %d", c.ID, c.GoferPid) if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { // The gofer may already be stopped, log the error. log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err) @@ -1082,7 +1108,13 @@ func (c *Container) adjustGoferOOMScoreAdj() error { // TODO(gvisor.dev/issue/238): This call could race with other containers being // created at the same time and end up setting the wrong oom_score_adj to the // sandbox. Use rpc client to synchronize. -func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error { +func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, spec *specs.Spec, rootDir string, destroy bool) error { + // Adjustment can be skipped if the root container is exiting, because it + // brings down the entire sandbox. + if isRoot(spec) && destroy { + return nil + } + containers, err := loadSandbox(rootDir, s.ID) if err != nil { return fmt.Errorf("loading sandbox containers: %v", err) @@ -1096,53 +1128,34 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Get the lowest score for all containers. var lowScore int scoreFound := false - if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified { - // This is a single-container sandbox. Set the oom_score_adj to - // the value specified in the OCI bundle. - if containers[0].Spec.Process.OOMScoreAdj != nil { - scoreFound = true - lowScore = *containers[0].Spec.Process.OOMScoreAdj + for _, container := range containers { + // Special multi-container support for CRI. Ignore the root container when + // calculating oom_score_adj for the sandbox because it is the + // infrastructure (pause) container and always has a very low oom_score_adj. + // + // We will use OOMScoreAdj in the single-container case where the + // containerd container-type annotation is not present. + if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { + continue } - } else { - for _, container := range containers { - // Special multi-container support for CRI. Ignore the root - // container when calculating oom_score_adj for the sandbox because - // it is the infrastructure (pause) container and always has a very - // low oom_score_adj. 
- // - // We will use OOMScoreAdj in the single-container case where the - // containerd container-type annotation is not present. - if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { - continue - } - if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { - scoreFound = true - lowScore = *container.Spec.Process.OOMScoreAdj - } + if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { + scoreFound = true + lowScore = *container.Spec.Process.OOMScoreAdj } } // If the container is destroyed and remaining containers have no - // oomScoreAdj specified then we must revert to the oom_score_adj of the - // parent process. + // oomScoreAdj specified then we must revert to the original oom_score_adj + // saved with the root container. if !scoreFound && destroy { - ppid, err := specutils.GetParentPid(s.Pid) - if err != nil { - return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err) - } - pScore, err := specutils.GetOOMScoreAdj(ppid) - if err != nil { - return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err) - } - + lowScore = containers[0].Sandbox.OriginalOOMScoreAdj scoreFound = true - lowScore = pScore } - // Only set oom_score_adj if one of the containers has oom_score_adj set - // in the OCI bundle. If not, we need to inherit the parent process's - // oom_score_adj. + // Only set oom_score_adj if one of the containers has oom_score_adj set. If + // not, oom_score_adj is inherited from the parent process. + // // See: https://github.com/opencontainers/runtime-spec/blob/master/config.md#linux-process if !scoreFound { return nil diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index cc188f45b..fa99e403a 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -364,7 +364,7 @@ func TestLifecycle(t *testing.T) { defer c.Destroy() // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -387,7 +387,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -428,7 +428,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -451,7 +451,7 @@ func TestLifecycle(t *testing.T) { } // Loading the container by id should fail. 
- if _, err = Load(rootDir, args.ID); err == nil { + if _, err = LoadAndCheck(rootDir, args.ID); err == nil { t.Errorf("expected loading destroyed container to fail, but it did not") } }) @@ -1738,7 +1738,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { cids[2]: cids[2], } for shortid, longid := range unambiguous { - if _, err := Load(rootDir, shortid); err != nil { + if _, err := LoadAndCheck(rootDir, shortid); err != nil { t.Errorf("%q should resolve to %q: %v", shortid, longid, err) } } @@ -1749,7 +1749,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { "ba", } for _, shortid := range ambiguous { - if s, err := Load(rootDir, shortid); err == nil { + if s, err := LoadAndCheck(rootDir, shortid); err == nil { t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID) } } @@ -1976,11 +1976,11 @@ func doDestroyNotStartedTest(t *testing.T, vfs2 bool) { // TestDestroyStarting attempts to force a race between start and destroy. func TestDestroyStarting(t *testing.T) { - doDestroyNotStartedTest(t, false) + doDestroyStartingTest(t, false) } func TestDestroyStartedVFS2(t *testing.T) { - doDestroyNotStartedTest(t, true) + doDestroyStartingTest(t, true) } func doDestroyStartingTest(t *testing.T, vfs2 bool) { @@ -2007,7 +2007,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { // Container is not thread safe, so load another instance to run in // concurrently. - startCont, err := Load(rootDir, args.ID) + startCont, err := LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 850e80290..cadc63bf3 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -15,6 +15,7 @@ package container import ( + "encoding/json" "fmt" "io/ioutil" "math" @@ -762,7 +763,7 @@ func TestMultiContainerKillAll(t *testing.T) { // processes still running inside. containers[1].SignalContainer(syscall.SIGKILL, false) op := func() error { - c, err := Load(conf.RootDir, ids[1]) + c, err := LoadAndCheck(conf.RootDir, ids[1]) if err != nil { return err } @@ -776,7 +777,7 @@ func TestMultiContainerKillAll(t *testing.T) { } } - c, err := Load(conf.RootDir, ids[1]) + c, err := LoadAndCheck(conf.RootDir, ids[1]) if err != nil { t.Fatalf("failed to load child container %q: %v", c.ID, err) } @@ -899,7 +900,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) { // Container is not thread safe, so load another instance to run in // concurrently. - startCont, err := Load(rootDir, ids[i]) + startCont, err := LoadAndCheck(rootDir, ids[i]) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -1766,3 +1767,72 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { }) } } + +func TestMultiContainerEvent(t *testing.T) { + conf := testutil.TestConfig(t) + rootDir, cleanup, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer cleanup() + conf.RootDir = rootDir + + // Setup the containers. + sleep := []string{"/bin/sleep", "100"} + quick := []string{"/bin/true"} + podSpec, ids := createSpecs(sleep, sleep, quick) + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + for _, cont := range containers { + t.Logf("Running containerd %s", cont.ID) + } + + // Wait for last container to stabilize the process count that is checked + // further below. 
+	if ws, err := containers[2].Wait(); err != nil || ws != 0 {
+		t.Fatalf("Container.Wait, status: %v, err: %v", ws, err)
+	}
+
+	// Check events for running containers.
+	for _, cont := range containers[:2] {
+		evt, err := cont.Event()
+		if err != nil {
+			t.Errorf("Container.Events(): %v", err)
+		}
+		if want := "stats"; evt.Type != want {
+			t.Errorf("Wrong event type, want: %s, got: %s", want, evt.Type)
+		}
+		if cont.ID != evt.ID {
+			t.Errorf("Wrong container ID, want: %s, got: %s", cont.ID, evt.ID)
+		}
+		// Event.Data is an interface, so it comes from the wire as a
+		// map[string]string. Marshal and unmarshal again to the correct type.
+		data, err := json.Marshal(evt.Data)
+		if err != nil {
+			t.Fatalf("invalid event data: %v", err)
+		}
+		var stats boot.Stats
+		if err := json.Unmarshal(data, &stats); err != nil {
+			t.Fatalf("invalid event data: %v", err)
+		}
+		// One process per remaining container.
+		if want := uint64(2); stats.Pids.Current != want {
+			t.Errorf("Wrong number of PIDs, want: %d, got: %d", want, stats.Pids.Current)
+		}
+	}
+
+	// Check that stopped and destroyed containers return an error.
+	if err := containers[1].Destroy(); err != nil {
+		t.Fatalf("container.Destroy: %v", err)
+	}
+	for _, cont := range containers[1:] {
+		_, err := cont.Event()
+		if err == nil {
+			t.Errorf("Container.Events() should have failed, cid: %s, state: %v", cont.ID, cont.Status)
+		}
+	}
+}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index c4309feb3..4a4110477 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -66,6 +66,10 @@ type Sandbox struct {
 	// Cgroup has the cgroup configuration for the sandbox.
 	Cgroup *cgroup.Cgroup `json:"cgroup"`
 
+	// OriginalOOMScoreAdj stores the value of oom_score_adj when the sandbox
+	// started, before it may be modified.
+	OriginalOOMScoreAdj int `json:"originalOomScoreAdj"`
+
 	// child is set if a sandbox process is a child of the current process.
 	//
 	// This field isn't saved to json, because only a creator of sandbox
@@ -739,6 +743,11 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
 		}
 		return err
 	}
+	s.OriginalOOMScoreAdj, err = specutils.GetOOMScoreAdj(cmd.Process.Pid)
+	if err != nil {
+		return err
+	}
+
 	s.child = true
 	s.Pid = cmd.Process.Pid
 	log.Infof("Sandbox started, PID: %d", s.Pid)
@@ -1133,11 +1142,11 @@ func (s *Sandbox) DestroyContainer(cid string) error {
 
 func (s *Sandbox) destroyContainer(cid string) error {
 	if s.IsRootContainer(cid) {
-		log.Debugf("Destroying root container %q by destroying sandbox", cid)
+		log.Debugf("Destroying root container by destroying sandbox, cid: %s", cid)
 		return s.destroy()
 	}
 
-	log.Debugf("Destroying container %q in sandbox %q", cid, s.ID)
+	log.Debugf("Destroying container, cid: %s, sandbox: %s", cid, s.ID)
 	conn, err := s.sandboxConnect()
 	if err != nil {
 		return err
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 0392e3e83..fdbba1832 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -19,6 +19,7 @@ package specutils
 import (
 	"encoding/json"
 	"fmt"
+	"io"
 	"io/ioutil"
 	"os"
 	"path"
@@ -169,7 +170,7 @@ func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) {
 
 // ReadSpecFromFile reads an OCI runtime spec from the given File, and
 // normalizes all relative paths into absolute by prepending the bundle dir.
func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) { - if _, err := specFile.Seek(0, os.SEEK_SET); err != nil { + if _, err := specFile.Seek(0, io.SeekStart); err != nil { return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err) } specBytes, err := ioutil.ReadAll(specFile) @@ -344,15 +345,9 @@ func IsSupportedDevMount(m specs.Mount) bool { var existingDevices = []string{ "/dev/fd", "/dev/stdin", "/dev/stdout", "/dev/stderr", "/dev/null", "/dev/zero", "/dev/full", "/dev/random", - "/dev/urandom", "/dev/shm", "/dev/pts", "/dev/ptmx", + "/dev/urandom", "/dev/shm", "/dev/ptmx", } dst := filepath.Clean(m.Destination) - if dst == "/dev" { - // OCI spec uses many different mounts for the things inside of '/dev'. We - // have a single mount at '/dev' that is always mounted, regardless of - // whether it was asked for, as the spec says we SHOULD. - return false - } for _, dev := range existingDevices { if dst == dev || strings.HasPrefix(dst, dev+"/") { return false @@ -425,7 +420,7 @@ func Mount(src, dst, typ string, flags uint32) error { // Special case, as there is no source directory for proc mounts. isDir = true } else if fi, err := os.Stat(src); err != nil { - return fmt.Errorf("Stat(%q) failed: %v", src, err) + return fmt.Errorf("stat(%q) failed: %v", src, err) } else { isDir = fi.IsDir() } @@ -433,25 +428,25 @@ func Mount(src, dst, typ string, flags uint32) error { if isDir { // Create the destination directory. if err := os.MkdirAll(dst, 0777); err != nil { - return fmt.Errorf("Mkdir(%q) failed: %v", dst, err) + return fmt.Errorf("mkdir(%q) failed: %v", dst, err) } } else { // Create the parent destination directory. parent := path.Dir(dst) if err := os.MkdirAll(parent, 0777); err != nil { - return fmt.Errorf("Mkdir(%q) failed: %v", parent, err) + return fmt.Errorf("mkdir(%q) failed: %v", parent, err) } // Create the destination file if it does not exist. f, err := os.OpenFile(dst, syscall.O_CREAT, 0777) if err != nil { - return fmt.Errorf("Open(%q) failed: %v", dst, err) + return fmt.Errorf("open(%q) failed: %v", dst, err) } f.Close() } // Do the mount. if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil { - return fmt.Errorf("Mount(%q, %q, %d) failed: %v", src, dst, flags, err) + return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err) } return nil } @@ -486,35 +481,6 @@ func GetOOMScoreAdj(pid int) (int, error) { return strconv.Atoi(strings.TrimSpace(string(data))) } -// GetParentPid gets the parent process ID of the specified PID. -func GetParentPid(pid int) (int, error) { - data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) - if err != nil { - return 0, err - } - - var cpid string - var name string - var state string - var ppid int - // Parse after the binary name. - _, err = fmt.Sscanf(string(data), - "%v %v %v %d", - // cpid is ignored. - &cpid, - // name is ignored. - &name, - // state is ignored. - &state, - &ppid) - - if err != nil { - return 0, err - } - - return ppid, nil -} - // EnvVar looks for a varible value in the env slice assuming the following // format: "NAME=VALUE". 
func EnvVar(env []string, name string) (string, bool) { diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD index 32c139204..7dfd4b693 100644 --- a/test/benchmarks/base/BUILD +++ b/test/benchmarks/base/BUILD @@ -13,7 +13,7 @@ go_library( go_test( name = "base_test", - size = "large", + size = "enormous", srcs = [ "size_test.go", "startup_test.go", diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index 32bf2a992..d3e5efd4f 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -441,9 +441,20 @@ func (FilterOutputDestination) Name() string { // ContainerAction implements TestCase.ContainerAction. func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { - rules := [][]string{ - {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, - {"-P", "OUTPUT", "DROP"}, + var rules [][]string + if ipv6 { + rules = [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + // Allow solicited node multicast addresses so we can send neighbor + // solicitations. + {"-A", "OUTPUT", "-d", "ff02::1:ff00:0/104", "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } + } else { + rules = [][]string{ + {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, + {"-P", "OUTPUT", "DROP"}, + } } if err := filterTableRules(ipv6, rules); err != nil { return err diff --git a/test/iptables/nat.go b/test/iptables/nat.go index dd9a18339..b98d99fb8 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -577,11 +577,18 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. connCh := make(chan int) errCh := make(chan error) go func() { - connFD, _, err := syscall.Accept(sockfd) - if err != nil { - errCh <- err + for { + connFD, _, err := syscall.Accept(sockfd) + if errors.Is(err, syscall.EINTR) { + continue + } + if err != nil { + errCh <- err + return + } + connCh <- connFD + return } - connCh <- connFD }() // Wait for accept() to return or for the context to finish. diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go index eecfe0730..006988896 100644 --- a/test/packetimpact/netdevs/netdevs.go +++ b/test/packetimpact/netdevs/netdevs.go @@ -40,7 +40,7 @@ var ( deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`) linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) - inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) + inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-F:/]+)`) ) // ParseDevices parses the output from `ip addr show` into a map from device diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 1546d0d51..c03c2c62c 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -252,6 +252,9 @@ ALL_TESTS = [ expect_netstack_failure = True, ), PacketimpactTestInfo( + name = "ipv4_fragment_reassembly", + ), + PacketimpactTestInfo( name = "ipv6_fragment_reassembly", ), PacketimpactTestInfo( diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index a90046f69..8fa585804 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -839,6 +839,61 @@ func (conn *TCPIPv4) Drain(t *testing.T) { conn.sniffer.Drain(t) } +// IPv4Conn maintains the state for all the layers in a IPv4 connection. +type IPv4Conn Connection + +// NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults. 
+func NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make EtherState: %s", err) + } + ipv4State, err := newIPv4State(outgoingIPv4, incomingIPv4) + if err != nil { + t.Fatalf("can't make IPv4State: %s", err) + } + + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return IPv4Conn{ + layerStates: []layerState{etherState, ipv4State}, + injector: injector, + sniffer: sniffer, + } +} + +// Send sends a frame with ipv4 overriding the IPv4 layer defaults and +// additionalLayers added after it. +func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(c).send(t, Layers{&ipv4}, additionalLayers...) +} + +// Close cleans up any resources held. +func (c *IPv4Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(c).Close(t) +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (c *IPv4Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(c).ExpectFrame(t, frame, timeout) +} + // IPv6Conn maintains the state for all the layers in a IPv6 connection. type IPv6Conn Connection diff --git a/test/packetimpact/testbench/dut_client.go b/test/packetimpact/testbench/dut_client.go index d0e68c5da..0fc3d97b4 100644 --- a/test/packetimpact/testbench/dut_client.go +++ b/test/packetimpact/testbench/dut_client.go @@ -19,7 +19,7 @@ import ( pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" ) -// PosixClient is a gRPC client for the Posix service. +// POSIXClient is a gRPC client for the Posix service. type POSIXClient pb.PosixClient // NewPOSIXClient makes a new gRPC client for the POSIX service. diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index a35562ca8..af7a2ba4e 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -879,6 +879,9 @@ type ICMPv4 struct { Type *header.ICMPv4Type Code *header.ICMPv4Code Checksum *uint16 + Ident *uint16 + Sequence *uint16 + Payload []byte } func (l *ICMPv4) String() string { @@ -887,7 +890,7 @@ func (l *ICMPv4) String() string { // ToBytes implements Layer.ToBytes. func (l *ICMPv4) ToBytes() ([]byte, error) { - b := make([]byte, header.ICMPv4MinimumSize) + b := make([]byte, header.ICMPv4MinimumSize+len(l.Payload)) h := header.ICMPv4(b) if l.Type != nil { h.SetType(*l.Type) @@ -895,15 +898,33 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { if l.Code != nil { h.SetCode(*l.Code) } + if copied := copy(h.Payload(), l.Payload); copied != len(l.Payload) { + panic(fmt.Sprintf("wrong number of bytes copied into h.Payload(): got = %d, want = %d", len(h.Payload()), len(l.Payload))) + } + if l.Ident != nil { + h.SetIdent(*l.Ident) + } + if l.Sequence != nil { + h.SetSequence(*l.Sequence) + } + + // The checksum must be handled last because the ICMPv4 header fields are + // included in the computation. if l.Checksum != nil { h.SetChecksum(*l.Checksum) - return h, nil - } - payload, err := payload(l) - if err != nil { - return nil, err + } else { + // Compute the checksum based on the ICMPv4.Payload and also the subsequent + // layers. 
+ payload, err := payload(l) + if err != nil { + return nil, err + } + var vv buffer.VectorisedView + vv.AppendView(buffer.View(l.Payload)) + vv.Append(payload) + h.SetChecksum(header.ICMPv4Checksum(h, vv)) } - h.SetChecksum(header.ICMPv4Checksum(h, payload)) + return h, nil } @@ -915,8 +936,11 @@ func parseICMPv4(b []byte) (Layer, layerParser) { Type: ICMPv4Type(h.Type()), Code: ICMPv4Code(h.Code()), Checksum: Uint16(h.Checksum()), + Ident: Uint16(h.Ident()), + Sequence: Uint16(h.Sequence()), + Payload: h.Payload(), } - return &icmpv4, parsePayload + return &icmpv4, nil } func (l *ICMPv4) match(other Layer) bool { diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 8c2de5a9f..c30c77a17 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -298,6 +298,18 @@ packetimpact_testbench( ) packetimpact_testbench( + name = "ipv4_fragment_reassembly", + srcs = ["ipv4_fragment_reassembly_test.go"], + deps = [ + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( name = "ipv6_fragment_reassembly", srcs = ["ipv6_fragment_reassembly_test.go"], deps = [ @@ -305,6 +317,7 @@ packetimpact_testbench( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go new file mode 100644 index 000000000..65c0df140 --- /dev/null +++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go @@ -0,0 +1,142 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
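Before the new test body below: the ICMPv4.ToBytes change above defers checksum computation until every header field and the payload are in place, because the ICMPv4 checksum is the standard Internet checksum (RFC 1071) over the whole message. A minimal standalone version of that computation (illustrative only; the testbench itself uses header.ICMPv4Checksum):

    package main

    import "fmt"

    // internetChecksum computes the RFC 1071 ones'-complement checksum over
    // msg (header with its checksum field zeroed, followed by the payload).
    func internetChecksum(msg []byte) uint16 {
        var sum uint32
        for i := 0; i+1 < len(msg); i += 2 {
            sum += uint32(msg[i])<<8 | uint32(msg[i+1])
        }
        if len(msg)%2 == 1 {
            sum += uint32(msg[len(msg)-1]) << 8 // pad the odd trailing byte
        }
        for sum>>16 != 0 {
            sum = (sum & 0xffff) + (sum >> 16) // fold the carries
        }
        return ^uint16(sum)
    }

    func main() {
        // ICMPv4 echo request header (type=8, code=0, checksum=0, id=0, seq=0)
        // followed by a two-byte payload.
        msg := []byte{8, 0, 0, 0, 0, 0, 0, 0, 0xab, 0xcd}
        fmt.Printf("checksum: %#04x\n", internetChecksum(msg))
    }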
+ +package ipv4_fragment_reassembly_test + +import ( + "flag" + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +type fragmentInfo struct { + offset uint16 + size uint16 + more uint8 +} + +func TestIPv4FragmentReassembly(t *testing.T) { + const fragmentID = 42 + icmpv4ProtoNum := uint8(header.ICMPv4ProtocolNumber) + + tests := []struct { + description string + ipPayloadLen int + fragments []fragmentInfo + expectReply bool + }{ + { + description: "basic reassembly", + ipPayloadLen: 2000, + fragments: []fragmentInfo{ + {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments}, + {offset: 1000, size: 1000, more: 0}, + }, + expectReply: true, + }, + { + description: "out of order fragments", + ipPayloadLen: 2000, + fragments: []fragmentInfo{ + {offset: 1000, size: 1000, more: 0}, + {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments}, + }, + expectReply: true, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{}) + defer conn.Close(t) + + data := make([]byte, test.ipPayloadLen) + icmp := header.ICMPv4(data[:header.ICMPv4MinimumSize]) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetSequence(0) + icmp.SetIdent(0) + originalPayload := data[header.ICMPv4MinimumSize:] + if _, err := rand.Read(originalPayload); err != nil { + t.Fatalf("rand.Read: %s", err) + } + cksum := header.ICMPv4Checksum( + icmp, + buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), + ) + icmp.SetChecksum(cksum) + + for _, fragment := range test.fragments { + conn.Send(t, + testbench.IPv4{ + Protocol: &icmpv4ProtoNum, + FragmentOffset: testbench.Uint16(fragment.offset), + Flags: testbench.Uint8(fragment.more), + ID: testbench.Uint16(fragmentID), + }, + &testbench.Payload{ + Bytes: data[fragment.offset:][:fragment.size], + }) + } + + var bytesReceived int + reassembledPayload := make([]byte, test.ipPayloadLen) + for { + incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv4{}, + &testbench.ICMPv4{}, + }, time.Second) + if err != nil { + // Either an unexpected frame was received, or none at all. 
+ if bytesReceived < test.ipPayloadLen { + t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err) + } + break + } + if !test.expectReply { + t.Fatalf("unexpected reply received:\n%s", incomingFrame) + } + ipPayload, err := incomingFrame[2 /* ICMPv4 */].ToBytes() + if err != nil { + t.Fatalf("failed to parse ICMPv4 header: incomingPacket[2].ToBytes() = (_, %s)", err) + } + offset := *incomingFrame[1 /* IPv4 */].(*testbench.IPv4).FragmentOffset + if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) { + t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload)) + } + bytesReceived += len(ipPayload) + } + + if test.expectReply { + if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv4MinimumSize:]); diff != "" { + t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go index a24c85566..4a29de688 100644 --- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -15,154 +15,137 @@ package ipv6_fragment_reassembly_test import ( - "bytes" - "encoding/binary" - "encoding/hex" "flag" + "math/rand" "net" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) -const ( - // The payload length for the first fragment we send. This number - // is a multiple of 8 near 750 (half of 1500). - firstPayloadLength = 752 - // The ID field for our outgoing fragments. - fragmentID = 1 - // A node must be able to accept a fragmented packet that, - // after reassembly, is as large as 1500 octets. - reassemblyCap = 1500 -) - func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestIPv6FragmentReassembly(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close(t) - - firstPayloadToSend := make([]byte, firstPayloadLength) - for i := range firstPayloadToSend { - firstPayloadToSend[i] = 'A' - } - - secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize - secondPayloadToSend := firstPayloadToSend[:secondPayloadLength] - - icmpv6EchoPayload := make([]byte, 4) - binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0) - binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0) - icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...) 
- - lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) - rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) - icmpv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - Payload: icmpv6EchoPayload, - } - icmpv6Bytes, err := icmpv6.ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6: %s", err) - } - cksum := header.ICMPv6Checksum( - header.ICMPv6(icmpv6Bytes), - lIP, - rIP, - buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), - ) - - conn.Send(t, testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - FragmentOffset: testbench.Uint16(0), - MoreFragments: testbench.Bool(true), - Identification: testbench.Uint32(fragmentID), - }, - &testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - Payload: icmpv6EchoPayload, - Checksum: &cksum, - }) +type fragmentInfo struct { + offset uint16 + size uint16 + more bool +} +func TestIPv6FragmentReassembly(t *testing.T) { + const fragmentID = 42 icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) - conn.Send(t, testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - NextHeader: &icmpv6ProtoNum, - FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), - MoreFragments: testbench.Bool(false), - Identification: testbench.Uint32(fragmentID), + tests := []struct { + description string + ipPayloadLen int + fragments []fragmentInfo + expectReply bool + }{ + { + description: "basic reassembly", + ipPayloadLen: 1500, + fragments: []fragmentInfo{ + {offset: 0, size: 760, more: true}, + {offset: 760, size: 740, more: false}, + }, + expectReply: true, }, - &testbench.Payload{ - Bytes: secondPayloadToSend, - }) - - gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ - &testbench.Ether{}, - &testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - FragmentOffset: testbench.Uint16(0), - MoreFragments: testbench.Bool(true), + { + description: "out of order fragments", + ipPayloadLen: 3000, + fragments: []fragmentInfo{ + {offset: 0, size: 1024, more: true}, + {offset: 2048, size: 952, more: false}, + {offset: 1024, size: 1024, more: true}, + }, + expectReply: true, }, - &testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - }, - }, time.Second) - if err != nil { - t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err) } - id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification - gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6: %s", err) - } - icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:] - receivedLen := len(icmpPayload) - wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen - wantFirstPayload := make([]byte, receivedLen) - for i := range wantFirstPayload { - wantFirstPayload[i] = 'A' - } - wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen] - if !bytes.Equal(icmpPayload, wantFirstPayload) { - t.Fatalf("received unexpected payload, got: %s, want: %s", - hex.Dump(icmpPayload), - hex.Dump(wantFirstPayload)) - } - - gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ - &testbench.Ether{}, - &testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - NextHeader: &icmpv6ProtoNum, - FragmentOffset: 
testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)), - MoreFragments: testbench.Bool(false), - Identification: &id, - }, - &testbench.ICMPv6{}, - }, time.Second) - if err != nil { - t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err) - } - secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err) - } - if !bytes.Equal(secondPayload, wantSecondPayload) { - t.Fatalf("received unexpected payload, got: %s, want: %s", - hex.Dump(secondPayload), - hex.Dump(wantSecondPayload)) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + + lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) + + data := make([]byte, test.ipPayloadLen) + icmp := header.ICMPv6(data[:header.ICMPv6HeaderSize]) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + originalPayload := data[header.ICMPv6HeaderSize:] + if _, err := rand.Read(originalPayload); err != nil { + t.Fatalf("rand.Read: %s", err) + } + + cksum := header.ICMPv6Checksum( + icmp, + lIP, + rIP, + buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), + ) + icmp.SetChecksum(cksum) + + for _, fragment := range test.fragments { + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16(fragment.offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + MoreFragments: testbench.Bool(fragment.more), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.Payload{ + Bytes: data[fragment.offset:][:fragment.size], + }) + } + + var bytesReceived int + reassembledPayload := make([]byte, test.ipPayloadLen) + for { + incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{}, + &testbench.ICMPv6{}, + }, time.Second) + if err != nil { + // Either an unexpected frame was received, or none at all. 
+ if bytesReceived < test.ipPayloadLen { + t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err) + } + break + } + if !test.expectReply { + t.Fatalf("unexpected reply received:\n%s", incomingFrame) + } + ipPayload, err := incomingFrame[3 /* ICMPv6 */].ToBytes() + if err != nil { + t.Fatalf("failed to parse ICMPv6 header: incomingPacket[3].ToBytes() = (_, %s)", err) + } + offset := *incomingFrame[2 /* IPv6FragmentExtHdr */].(*testbench.IPv6FragmentExtHdr).FragmentOffset + offset *= header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) { + t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload)) + } + bytesReceived += len(ipPayload) + } + + if test.expectReply { + if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv6HeaderSize:]); diff != "" { + t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + } + }) } } diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go index 2f57dff19..8a1fe1279 100644 --- a/test/packetimpact/tests/tcp_network_unreachable_test.go +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -74,7 +74,9 @@ func TestTCPSynSentUnreachable(t *testing.T) { } var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{ Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), - Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable), + } + layers = append(layers, &icmpv4, ip, tcp) rawConn.SendFrameStateless(t, layers) diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 4243eb59e..0dcc0fdea 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -40,11 +40,7 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid()) if err != nil { t.Fatalf("getting parent oom_score_adj: %v", err) } @@ -122,11 +118,7 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid()) if err != nil { t.Fatalf("getting parent oom_score_adj: %v", err) } diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 9b5994d59..4992147d4 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -97,10 +97,10 @@ def _syscall_test( # we figure out how to request ipv4 sockets on Guitar machines. if network == "host": tags.append("noguitar") - tags.append("block-network") # Disable off-host networking. tags.append("requires-net:loopback") + tags.append("block-network") # gotsan makes sense only if tests are running in gVisor. 
if platform == "native": diff --git a/test/runner/runner.go b/test/runner/runner.go index 22d535f8d..7ab2c3edf 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -53,6 +53,9 @@ var ( runscPath = flag.String("runsc", "", "path to runsc binary") addUDSTree = flag.Bool("add-uds-tree", false, "expose a tree of UDS utilities for use in tests") + // TODO(gvisor.dev/issue/4572): properly support leak checking for runsc, and + // set to true as the default for the test runner. + leakCheck = flag.Bool("leak-check", false, "check for reference leaks") ) // runTestCaseNative runs the test case directly on the host machine. @@ -174,6 +177,9 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { if *addUDSTree { args = append(args, "-fsgofer-host-uds") } + if *leakCheck { + args = append(args, "-ref-leak-mode=log-names") + } testLogDir := "" if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { diff --git a/test/runtimes/exclude/java11.csv b/test/runtimes/exclude/java11.csv index d978baca7..e41441374 100644 --- a/test/runtimes/exclude/java11.csv +++ b/test/runtimes/exclude/java11.csv @@ -144,6 +144,7 @@ jdk/jfr/cmd/TestSplit.java,,java.lang.RuntimeException: 'Missing file' missing f jdk/jfr/cmd/TestSummary.java,,java.lang.RuntimeException: 'Missing file' missing from stdout/stderr jdk/jfr/event/compiler/TestCompilerStats.java,,java.lang.RuntimeException: Field nmetodsSize not in event jdk/jfr/event/metadata/TestDefaultConfigurations.java,,Setting 'threshold' in event 'jdk.SecurityPropertyModification' was not configured in the configuration 'default' +jdk/jfr/event/oldobject/TestLargeRootSet.java,,Flaky - `main' threw exception: java.lang.RuntimeException: Could not find root object jdk/jfr/event/runtime/TestActiveSettingEvent.java,,java.lang.Exception: Could not find setting with name jdk.X509Validation#threshold jdk/jfr/event/runtime/TestModuleEvents.java,,java.lang.RuntimeException: assertEquals: expected jdk.proxy1 to equal java.base jdk/jfr/event/runtime/TestNetworkUtilizationEvent.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index ba993814f..c4e7917ec 100644 --- a/test/runtimes/exclude/nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -1,31 +1,22 @@ test name,bug id,comment async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29 -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-napi.js,, +benchmark/test-benchmark-fs.js,,Broken test +benchmark/test-benchmark-napi.js,,Broken test doctool/test-make-doc.js,b/68848110,Expected to fail. internet/test-dgram-multicast-set-interface-lo.js,b/162798882, -internet/test-doctool-versions.js,, -internet/test-uv-threadpool-schedule.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, +internet/test-doctool-versions.js,,Broken test +internet/test-uv-threadpool-schedule.js,,Broken test parallel/test-dgram-bind-fd.js,b/132447356, parallel/test-dgram-socket-buffer-size.js,b/68847921, parallel/test-dns-channel-timeout.js,b/161893056, -parallel/test-fs-access.js,, -parallel/test-fs-watchfile.js,,Flaky - File already exists error -parallel/test-fs-write-stream.js,b/166819807,Flaky -parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky -parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky -parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. 
Expected exactly 1 actual 2 +parallel/test-fs-access.js,,Broken test +parallel/test-fs-watchfile.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream-double-close.js,b/166819807,Flaky - VFS1 only +parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky - VFS1 only +parallel/test-http-writable-true-after-close.js,b/171301436,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 parallel/test-os.js,b/63997097, -parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE -parallel/test-process-uid-gid.js,, -parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE -parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE -parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE +parallel/test-process-uid-gid.js,,Does not work inside Docker with gid nobody pseudo-tty/test-assert-colors.js,b/162801321, pseudo-tty/test-assert-no-color.js,b/162801321, pseudo-tty/test-assert-position-indicator.js,b/162801321, @@ -48,11 +39,7 @@ pseudo-tty/test-tty-stdout-resize.js,b/162801321, pseudo-tty/test-tty-stream-constructors.js,b/162801321, pseudo-tty/test-tty-window-size.js,b/162801321, pseudo-tty/test-tty-wrap.js,b/162801321, -pummel/test-heapdump-http2.js,,Flaky -pummel/test-net-pingpong.js,, +pummel/test-net-pingpong.js,,Broken test pummel/test-vm-memleak.js,b/162799436, -pummel/test-watch-file.js,,Flaky - Timeout -sequential/test-child-process-pass-fd.js,b/63926391,Flaky -sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE -sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout -tick-processor/test-tick-processor-builtin.js,, +pummel/test-watch-file.js,,Flaky - VFS1 only +tick-processor/test-tick-processor-builtin.js,,Broken test diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index a73f3bcfb..c051fe571 100644 --- a/test/runtimes/exclude/php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv @@ -8,6 +8,7 @@ ext/mbstring/tests/bug77165.phpt,, ext/mbstring/tests/bug77454.phpt,, ext/mbstring/tests/mb_convert_encoding_leak.phpt,, ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/pcre/tests/cache_limit.phpt,,Broken test - Flaky ext/session/tests/session_module_name_variation4.phpt,,Flaky ext/session/tests/session_set_save_handler_class_018.phpt,, ext/session/tests/session_set_save_handler_iface_003.phpt,, @@ -26,13 +27,14 @@ ext/standard/tests/file/php_fd_wrapper_01.phpt,, ext/standard/tests/file/php_fd_wrapper_02.phpt,, ext/standard/tests/file/php_fd_wrapper_03.phpt,, ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,b/162894969, +ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,VFS1 only failure ext/standard/tests/file/rename_variation.phpt,b/68717309, ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,VFS1 only failure ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until 
php-src 3852a35fdbcb +ext/standard/tests/streams/proc_open_bug64438.phpt,,Flaky ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky ext/standard/tests/streams/stream_socket_sendto.phpt,, ext/standard/tests/strings/007.phpt,, diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv index 8760f8951..e9fef03b7 100644 --- a/test/runtimes/exclude/python3.7.3.csv +++ b/test/runtimes/exclude/python3.7.3.csv @@ -4,7 +4,6 @@ test_asyncore,b/162973328, test_epoll,b/162983393, test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. test_httplib,b/163000009,OSError: [Errno 98] Address already in use -test_imaplib,b/162979661, test_logging,b/162980079, test_multiprocessing_fork,,Flaky. Sometimes times out. test_multiprocessing_forkserver,,Flaky. Sometimes times out. @@ -18,4 +17,3 @@ test_selectors,b/76116849,OSError not raised with epoll test_smtplib,b/162980434,unclosed sockets test_signal,,Flaky - signal: alarm clock test_socket,b/75983380, -test_subprocess,b/162980831, diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go index e5607ac92..81cb68381 100644 --- a/test/runtimes/proctor/main.go +++ b/test/runtimes/proctor/main.go @@ -22,6 +22,7 @@ import ( "log" "os" "strings" + "syscall" "gvisor.dev/gvisor/test/runtimes/proctor/lib" ) @@ -33,6 +34,29 @@ var ( pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") ) +// setNumFilesLimit changes the NOFILE soft rlimit if it is too high. +func setNumFilesLimit() error { + // In docker containers, the default value of the NOFILE limit is + // 1048576. A few runtime tests (e.g. python:test_subprocess) + // enumerates all possible file descriptors and these tests can fail by + // timeout if the NOFILE limit is too high. On gVisor, syscalls are + // slower so these tests will need even more time to pass. + const nofile = 32768 + rLimit := syscall.Rlimit{} + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) + } + if rLimit.Cur > nofile { + rLimit.Cur = nofile + err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) + } + } + return nil +} + func main() { flag.Parse() @@ -74,6 +98,10 @@ func main() { tests = strings.Split(*testNames, ",") } + if err := setNumFilesLimit(); err != nil { + log.Fatalf("%v", err) + } + // Run tests. cmds := tr.TestCmds(tests) for _, cmd := range cmds { diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go index 78285cb0e..64e6e14db 100644 --- a/test/runtimes/runner/lib/lib.go +++ b/test/runtimes/runner/lib/lib.go @@ -34,8 +34,16 @@ import ( // RunTests is a helper that is called by main. It exists so that we can run // defered functions before exiting. It returns an exit code that should be // passed to os.Exit. -func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int { - // Get tests to exclude.. +func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, batchSize int, timeout time.Duration) int { + if partitionNum <= 0 || totalPartitions <= 0 || partitionNum > totalPartitions { + fmt.Fprintf(os.Stderr, "invalid partition %d of %d", partitionNum, totalPartitions) + return 1 + } + + // TODO(gvisor.dev/issue/1624): Remove those tests from all exclude lists + // that only fail with VFS1. + + // Get tests to exclude. 
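In getTests, shown further below, the sorted test list is sliced into equal-sized partitions and the final partition absorbs any remainder. A small worked example of that slicing (illustrative only, same arithmetic as the patch):

    package main

    import "fmt"

    // partition returns the 1-indexed partitionNum-th slice of tests out of
    // totalPartitions, with the last partition taking any remainder.
    func partition(tests []string, partitionNum, totalPartitions int) []string {
        size := len(tests) / totalPartitions
        if partitionNum == totalPartitions {
            return tests[(partitionNum-1)*size:]
        }
        return tests[(partitionNum-1)*size : partitionNum*size]
    }

    func main() {
        tests := []string{"a", "b", "c", "d", "e", "f", "g"}
        for p := 1; p <= 3; p++ {
            // 7 tests over 3 partitions: sizes 2, 2 and 3.
            fmt.Println(p, partition(tests, p, 3))
        }
    }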
excludes, err := getExcludes(excludeFile) if err != nil { fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) @@ -55,7 +63,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat // Get a slice of tests to run. This will also start a single Docker // container that will be used to run each test. The final test will // stop the Docker container. - tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes) + tests, err := getTests(ctx, d, lang, image, partitionNum, totalPartitions, batchSize, timeout, excludes) if err != nil { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) return 1 @@ -66,7 +74,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat } // getTests executes all tests as table tests. -func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { +func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, partitionNum, totalPartitions, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { // Start the container. opts := dockerutil.RunOpts{ Image: fmt.Sprintf("runtimes/%s", image), @@ -86,6 +94,14 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, // shard. tests := strings.Fields(list) sort.Strings(tests) + + partitionSize := len(tests) / totalPartitions + if partitionNum == totalPartitions { + tests = tests[(partitionNum-1)*partitionSize:] + } else { + tests = tests[(partitionNum-1)*partitionSize : partitionNum*partitionSize] + } + indices, err := testutil.TestIndicesForShard(len(tests)) if err != nil { return nil, fmt.Errorf("TestsForShard() failed: %v", err) @@ -116,8 +132,15 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, err error ) + state, err := d.Status(ctx) + if err != nil { + t.Fatalf("Could not find container status: %v", err) + } + if !state.Running { + t.Fatalf("container is not running: state = %s", state.Status) + } + go func() { - fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ",")) close(done) }() @@ -125,12 +148,12 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, select { case <-done: if err == nil { - fmt.Printf("PASS: (%v)\n\n", time.Since(now)) + fmt.Printf("PASS: (%v) %d tests passed\n", time.Since(now), len(tcs)) return } - t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) + t.Errorf("FAIL: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output) case <-time.After(timeout): - t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) + t.Errorf("TIMEOUT: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output) } }, }) diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index ec79a22c2..5b3443e36 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -25,11 +25,13 @@ import ( ) var ( - lang = flag.String("lang", "", "language runtime to test") - image = flag.String("image", "", "docker image with runtime tests") - excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") - batchSize = flag.Int("batch", 50, "number of test cases run in one 
command") - timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") + lang = flag.String("lang", "", "language runtime to test") + image = flag.String("image", "", "docker image with runtime tests") + excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") + partition = flag.Int("partition", 1, "partition number, this is 1-indexed") + totalPartitions = flag.Int("total_partitions", 1, "total number of partitions") + batchSize = flag.Int("batch", 50, "number of test cases run in one command") + timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") ) func main() { @@ -38,5 +40,5 @@ func main() { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout)) + os.Exit(lib.RunTests(*lang, *image, *excludeFile, *partition, *totalPartitions, *batchSize, *timeout)) } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index f66a9ceb4..b5a4ef4df 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -695,6 +695,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:socket_ip_unbound_netlink_test", +) + +syscall_test( test = "//test/syscalls/linux:socket_netdevice_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 572f39a5d..2350f7e69 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1285,6 +1285,7 @@ cc_binary( "//test/util:mount_util", "//test/util:multiprocess_util", "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", @@ -1801,10 +1802,14 @@ cc_binary( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", "//test/util:logging", + "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:platform_util", "//test/util:signal_util", + "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", "//test/util:time_util", @@ -2101,10 +2106,12 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2124,6 +2131,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2137,10 +2145,12 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2434,6 +2444,7 @@ cc_library( "@com_google_absl//absl/memory", gtest, "//test/util:posix_error", + "//test/util:save_util", "//test/util:test_util", ], alwayslink = 1, @@ -2878,6 +2889,24 @@ cc_binary( ) cc_binary( + name = "socket_ip_unbound_netlink_test", + testonly = 1, + srcs = [ + "socket_ip_unbound_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "socket_domain_test", testonly = 1, srcs = [ @@ -3441,6 +3470,7 @@ cc_binary( "@com_google_absl//absl/strings", gtest, "//test/util:posix_error", + "//test/util:save_util", "//test/util:temp_path", 
"//test/util:test_main", "//test/util:test_util", diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index b96907b30..1635c6d0c 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -125,6 +125,16 @@ TEST(MknodTest, Socket) { ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds()); } +PosixErrorOr<FileDescriptor> OpenRetryEINTR(std::string const& path, int flags, + mode_t mode = 0) { + while (true) { + auto maybe_fd = Open(path, flags, mode); + if (maybe_fd.ok() || maybe_fd.error().errno_value() != EINTR) { + return maybe_fd; + } + } +} + TEST(MknodTest, Fifo) { const std::string fifo = NewTempAbsPath(); ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), @@ -139,14 +149,16 @@ TEST(MknodTest, Fifo) { // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); // Write-end of the pipe. - FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY)); + FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_WRONLY)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); } @@ -164,15 +176,16 @@ TEST(MknodTest, FifoOtrunc) { std::vector<char> buf(512); // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); // Write-end of the pipe. - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); + FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE( + OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); } @@ -192,14 +205,15 @@ TEST(MknodTest, FifoTruncNoOp) { std::vector<char> buf(512); // Read-end of the pipe. ScopedThread t([&fifo, &buf, &msg]() { - FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_RDONLY)); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(OpenRetryEINTR(fifo.c_str(), O_RDONLY)); EXPECT_THAT(ReadFd(fd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(msg.length())); EXPECT_EQ(msg, std::string(buf.data())); }); - FileDescriptor wfd = - ASSERT_NO_ERRNO_AND_VALUE(Open(fifo.c_str(), O_WRONLY | O_TRUNC)); + FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE( + OpenRetryEINTR(fifo.c_str(), O_WRONLY | O_TRUNC)); EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(WriteFd(wfd.get(), msg.c_str(), msg.length()), SyscallSucceedsWithValue(msg.length())); diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index e52c9cbcb..83546830d 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -592,6 +592,12 @@ TEST_F(MMapTest, ProtExec) { memcpy(reinterpret_cast<void*>(addr), machine_code, sizeof(machine_code)); +#if defined(__aarch64__) + // We use this as a memory barrier for Arm64. 
+ ASSERT_THAT(Protect(addr, kPageSize, PROT_READ | PROT_EXEC), + SyscallSucceeds()); +#endif + func = reinterpret_cast<uint32_t (*)(void)>(addr); EXPECT_EQ(42, func()); diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 3aab25b23..d65b7d031 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -34,6 +34,7 @@ #include "test/util/mount_util.h" #include "test/util/multiprocess_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -131,7 +132,9 @@ TEST(MountTest, UmountDetach) { ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "tmpfs", 0, "mode=0700", /* umountflags= */ MNT_DETACH)); const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_NE(before.st_ino, after.st_ino); + EXPECT_FALSE(before.st_dev == after.st_dev && before.st_ino == after.st_ino) + << "mount point has device number " << before.st_dev + << " and inode number " << before.st_ino << " before and after mount"; // Create files in the new mount. constexpr char kContents[] = "no no no"; @@ -147,12 +150,14 @@ TEST(MountTest, UmountDetach) { // Unmount the tmpfs. mount.Release()(); - // Only check for inode number equality if the directory is not in overlayfs. - // If xino option is not enabled and if all overlayfs layers do not belong to - // the same filesystem then "the value of st_ino for directory objects may not - // be persistent and could change even while the overlay filesystem is - // mounted." -- Documentation/filesystems/overlayfs.txt - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + // Inode numbers for gofer-accessed files may change across save/restore. + // + // For overlayfs, if xino option is not enabled and if all overlayfs layers do + // not belong to the same filesystem then "the value of st_ino for directory + // objects may not be persistent and could change even while the overlay + // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(before.st_ino, after2.st_ino); } @@ -214,18 +219,23 @@ TEST(MountTest, MountTmpfs) { const struct stat s = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(s.st_mode, S_IFDIR | 0700); - EXPECT_NE(s.st_ino, before.st_ino); + EXPECT_FALSE(before.st_dev == s.st_dev && before.st_ino == s.st_ino) + << "mount point has device number " << before.st_dev + << " and inode number " << before.st_ino << " before and after mount"; EXPECT_NO_ERRNO(Open(JoinPath(dir.path(), "foo"), O_CREAT | O_RDWR, 0777)); } // Now that dir is unmounted again, we should have the old inode back. - // Only check for inode number equality if the directory is not in overlayfs. - // If xino option is not enabled and if all overlayfs layers do not belong to - // the same filesystem then "the value of st_ino for directory objects may not - // be persistent and could change even while the overlay filesystem is - // mounted." -- Documentation/filesystems/overlayfs.txt - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + // + // Inode numbers for gofer-accessed files may change across save/restore. 
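The mount.cc change above swaps a bare inode comparison for a check on the (st_dev, st_ino) pair, since either number on its own can coincidentally collide across filesystems. A rough sketch of the underlying check; it assumes CAP_SYS_ADMIN and a hypothetical pre-created empty directory /tmp/mnt:

```cpp
#include <sys/mount.h>
#include <sys/stat.h>
#include <cstdio>

int main() {
  struct stat before, after;
  if (stat("/tmp/mnt", &before) != 0) { perror("stat"); return 1; }
  if (mount("none", "/tmp/mnt", "tmpfs", 0, "mode=0700") != 0) {
    perror("mount");  // Requires CAP_SYS_ADMIN.
    return 1;
  }
  if (stat("/tmp/mnt", &after) != 0) { perror("stat"); return 1; }
  // A different filesystem instance now covers the directory, so the
  // (device, inode) pair changes even if one of the two numbers happens to match.
  printf("before dev/ino %lu/%lu, after dev/ino %lu/%lu\n",
         (unsigned long)before.st_dev, (unsigned long)before.st_ino,
         (unsigned long)after.st_dev, (unsigned long)after.st_ino);
  umount2("/tmp/mnt", MNT_DETACH);
  return 0;
}
```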
+ // + // For overlayfs, if xino option is not enabled and if all overlayfs layers do + // not belong to the same filesystem then "the value of st_ino for directory + // objects may not be persistent and could change even while the overlay + // filesystem is mounted." -- Documentation/filesystems/overlayfs.txt + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); EXPECT_EQ(before.st_ino, after.st_ino); } diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index b558e3a01..a7c46adbf 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -664,6 +664,17 @@ TEST_P(RawPacketTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +TEST_P(RawPacketTest, GetSocketAcceptConn) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e8fcc4439..7a0f33dff 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -26,6 +26,7 @@ #include <string.h> #include <sys/mman.h> #include <sys/prctl.h> +#include <sys/ptrace.h> #include <sys/stat.h> #include <sys/statfs.h> #include <sys/utsname.h> @@ -512,6 +513,414 @@ TEST(ProcSelfAuxv, EntryValues) { EXPECT_EQ(i, proc_auxv.size()); } +// Just open and read a part of /proc/self/mem, check that we can read an item. +TEST(ProcPidMem, Read) { + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + char input[] = "hello-world"; + char output[sizeof(input)]; + ASSERT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(input)), + SyscallSucceedsWithValue(sizeof(input))); + ASSERT_STREQ(input, output); +} + +// Perform read on an unmapped region. +TEST(ProcPidMem, Unmapped) { + // Strategy: map then unmap, so we have a guaranteed unmapped region + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + // Fill it with things + memset(mapping.ptr(), 'x', mapping.len()); + char expected = 'x', output; + ASSERT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(mapping.ptr())), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(expected, output); + + // Unmap region again + ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds()); + + // Now we want EIO error + ASSERT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(mapping.ptr())), + SyscallFailsWithErrno(EIO)); +} + +// Perform read repeatedly to verify offset change. 
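The ProcPidMem tests above all build on the same mechanism: the offset into /proc/[pid]/mem is interpreted as a virtual address in the target address space. A minimal self-read sketch of that mechanism:

```cpp
#include <fcntl.h>
#include <unistd.h>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  char secret[] = "hello-world";
  int fd = open("/proc/self/mem", O_RDONLY);
  if (fd < 0) { perror("open"); return 1; }

  char out[sizeof(secret)] = {};
  // The file offset is the virtual address to read from.
  ssize_t n = pread(fd, out, sizeof(out), (off_t)(uintptr_t)secret);
  if (n != (ssize_t)sizeof(out)) { perror("pread"); return 1; }
  printf("read back: %s (match=%d)\n", out, strcmp(secret, out) == 0);
  close(fd);
  return 0;
}
```

Reading an address that is not mapped fails with EIO, which is exactly what the Unmapped test above asserts.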
+TEST(ProcPidMem, RepeatedRead) { + auto const num_reads = 3; + char expected[] = "01234567890abcdefghijkl"; + char output[sizeof(expected) / num_reads]; + + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + for (auto i = 0; i < num_reads; i++) { + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[i * sizeof(output)], output, sizeof(output)), + 0); + } +} + +// Perform seek operations repeatedly. +TEST(ProcPidMem, RepeatedSeek) { + auto const num_reads = 3; + char expected[] = "01234567890abcdefghijkl"; + char output[sizeof(expected) / num_reads]; + + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + ASSERT_THAT(lseek(memfd.get(), reinterpret_cast<off_t>(&expected), SEEK_SET), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + // Read from start + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0); + // Skip ahead one read + ASSERT_THAT(lseek(memfd.get(), sizeof(output), SEEK_CUR), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected) + + sizeof(output) * 2)); + // Do read again + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[2 * sizeof(output)], output, sizeof(output)), 0); + // Skip back three reads + ASSERT_THAT(lseek(memfd.get(), -3 * sizeof(output), SEEK_CUR), + SyscallSucceedsWithValue(reinterpret_cast<off_t>(&expected))); + // Do read again + ASSERT_THAT(read(memfd.get(), &output, sizeof(output)), + SyscallSucceedsWithValue(sizeof(output))); + ASSERT_EQ(strncmp(&expected[0 * sizeof(output)], output, sizeof(output)), 0); + // Check that SEEK_END does not work + ASSERT_THAT(lseek(memfd.get(), 0, SEEK_END), SyscallFailsWithErrno(EINVAL)); +} + +// Perform read past an allocated memory region. +TEST(ProcPidMem, PartialRead) { + // Strategy: map large region, then do unmap and remap smaller region + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/self/mem", O_RDONLY)); + + Mapping mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(2 * kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)); + ASSERT_THAT(munmap(mapping.ptr(), mapping.len()), SyscallSucceeds()); + Mapping smaller_mapping = ASSERT_NO_ERRNO_AND_VALUE( + Mmap(mapping.ptr(), kPageSize, PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0)); + + // Fill it with things + memset(smaller_mapping.ptr(), 'x', smaller_mapping.len()); + + // Now we want no error + char expected[] = {'x'}; + std::unique_ptr<char[]> output(new char[kPageSize]); + off_t read_offset = + reinterpret_cast<off_t>(smaller_mapping.ptr()) + kPageSize - 1; + ASSERT_THAT( + pread(memfd.get(), output.get(), sizeof(output.get()), read_offset), + SyscallSucceedsWithValue(sizeof(expected))); + // Since output is larger, than expected we have to do manual compare + ASSERT_EQ(expected[0], (output).get()[0]); +} + +// Perform read on /proc/[pid]/mem after exit. 
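RepeatedRead and RepeatedSeek above also pin down the seek semantics of /proc/[pid]/mem: SEEK_SET and SEEK_CUR position the offset at an address and plain read()s advance it, while SEEK_END is rejected with EINVAL because the file has no meaningful end. A short sketch of that behaviour:

```cpp
#include <fcntl.h>
#include <unistd.h>
#include <cerrno>
#include <cstdint>
#include <cstdio>

int main() {
  char data[] = "0123456789";
  int fd = open("/proc/self/mem", O_RDONLY);
  if (fd < 0) { perror("open"); return 1; }

  // Position the offset at the address of `data`, then read sequentially.
  lseek(fd, (off_t)(uintptr_t)data, SEEK_SET);
  char a = 0, b = 0;
  read(fd, &a, 1);  // '0'; the offset advances by one byte.
  read(fd, &b, 1);  // '1'
  printf("sequential reads: %c %c\n", a, b);

  // SEEK_END has no meaning for this file and fails with EINVAL.
  if (lseek(fd, 0, SEEK_END) < 0 && errno == EINVAL)
    printf("SEEK_END rejected with EINVAL, as expected\n");
  close(fd);
  return 0;
}
```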
+TEST(ProcPidMem, AfterExit) { + int pfd1[2] = {}; + int pfd2[2] = {}; + + char expected[] = "hello-world"; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char ok = 1; + TEST_CHECK(WriteFd(pfd1[1], &ok, sizeof(ok)) == sizeof(ok)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. + ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Wait for child to be alive and well + char ok = 0; + EXPECT_THAT(ReadFd(pfd1[0], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + + // Expect that we can read + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(&expected)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + + // Expect that we can't read anymore + EXPECT_THAT(pread(memfd.get(), &output, sizeof(output), + reinterpret_cast<off_t>(&expected)), + SyscallSucceedsWithValue(0)); +} + +// Read from /proc/[pid]/mem with different UID/GID and attached state. +TEST(ProcPidMem, DifferentUserAttached) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_DAC_OVERRIDE))); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_PTRACE))); + + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. 
+ ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + ScopedThread([&] { + // Attach to child subprocess without stopping it + EXPECT_THAT(ptrace(PTRACE_SEIZE, child_pid, NULL, NULL), SyscallSucceeds()); + + // Keep capabilities after setuid + EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds()); + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + + // Only restore CAP_SYS_PTRACE and CAP_DAC_OVERRIDE + EXPECT_NO_ERRNO(SetCapability(CAP_SYS_PTRACE, true)); + EXPECT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, true)); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + char expected[] = "hello-world"; + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(target_location)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0) + << " status " << status; + }); +} + +// Attempt to read from /proc/[pid]/mem with different UID/GID. +TEST(ProcPidMem, DifferentUser) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID))); + + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + + // In parent process. 
+ ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + ScopedThread([&] { + constexpr int kNobody = 65534; + EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); + + // Attempt to open /proc/[child_pid]/mem + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + EXPECT_THAT(open(mempath.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); + }); +} + +// Perform read on /proc/[pid]/mem with same UID/GID. +TEST(ProcPidMem, SameUser) { + int pfd1[2] = {}; + int pfd2[2] = {}; + + ASSERT_THAT(pipe(pfd1), SyscallSucceeds()); + ASSERT_THAT(pipe(pfd2), SyscallSucceeds()); + + // Create child process + pid_t const child_pid = fork(); + if (child_pid == 0) { + // Close reading end of first pipe + close(pfd1[0]); + + // Tell parent about location of input + char input[] = "hello-world"; + off_t input_location = reinterpret_cast<off_t>(input); + TEST_CHECK(WriteFd(pfd1[1], &input_location, sizeof(input_location)) == + sizeof(input_location)); + TEST_PCHECK(close(pfd1[1]) == 0); + + // Close writing end of second pipe + TEST_PCHECK(close(pfd2[1]) == 0); + + // Await parent OK to die + char ok = 0; + TEST_CHECK(ReadFd(pfd2[0], &ok, sizeof(ok)) == sizeof(ok)); + + // Close rest pipes + TEST_PCHECK(close(pfd2[0]) == 0); + _exit(0); + } + // In parent process. 
+ ASSERT_THAT(child_pid, SyscallSucceeds()); + + // Close writing end of first pipe + EXPECT_THAT(close(pfd1[1]), SyscallSucceeds()); + + // Read target location from child + off_t target_location; + EXPECT_THAT(ReadFd(pfd1[0], &target_location, sizeof(target_location)), + SyscallSucceedsWithValue(sizeof(target_location))); + // Close reading end of first pipe + EXPECT_THAT(close(pfd1[0]), SyscallSucceeds()); + + // Open /proc/pid/mem fd + std::string mempath = absl::StrCat("/proc/", child_pid, "/mem"); + auto memfd = ASSERT_NO_ERRNO_AND_VALUE(Open(mempath, O_RDONLY)); + char expected[] = "hello-world"; + char output[sizeof(expected)]; + EXPECT_THAT(pread(memfd.get(), output, sizeof(output), + reinterpret_cast<off_t>(target_location)), + SyscallSucceedsWithValue(sizeof(output))); + EXPECT_STREQ(expected, output); + + // Tell proc its ok to go + EXPECT_THAT(close(pfd2[0]), SyscallSucceeds()); + char ok = 1; + EXPECT_THAT(WriteFd(pfd2[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + EXPECT_THAT(close(pfd2[1]), SyscallSucceeds()); + + // Expect termination + int status; + ASSERT_THAT(waitpid(child_pid, &status, 0), SyscallSucceeds()); +} + // Just open and read /proc/self/maps, check that we can find [stack] TEST(ProcSelfMaps, Basic) { auto proc_self_maps = diff --git a/test/syscalls/linux/proc_pid_smaps.cc b/test/syscalls/linux/proc_pid_smaps.cc index 9fb1b3a2c..738923822 100644 --- a/test/syscalls/linux/proc_pid_smaps.cc +++ b/test/syscalls/linux/proc_pid_smaps.cc @@ -191,7 +191,7 @@ PosixErrorOr<std::vector<ProcPidSmapsEntry>> ParseProcPidSmaps( // amount of whitespace). if (!entry) { std::cerr << "smaps line not considered a maps line: " - << maybe_maps_entry.error_message() << std::endl; + << maybe_maps_entry.error().message() << std::endl; return PosixError( EINVAL, absl::StrCat("smaps field line without preceding maps line: ", l)); diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 926690eb8..13c19d4a8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -30,10 +30,13 @@ #include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/fs_util.h" #include "test/util/logging.h" +#include "test/util/memory_util.h" #include "test/util/multiprocess_util.h" #include "test/util/platform_util.h" #include "test/util/signal_util.h" +#include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" #include "test/util/time_util.h" @@ -113,10 +116,21 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { // except disabled. SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 0); - constexpr long kBeforePokeDataValue = 10; - constexpr long kAfterPokeDataValue = 20; + // Test PTRACE_POKE/PEEKDATA on both anonymous and file mappings. 
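The ptrace.cc change here extends the existing PEEKDATA/POKEDATA coverage from an anonymous stack word to a MAP_SHARED file mapping. For reference, a minimal standalone sketch of the peek/poke mechanism itself; unlike the gVisor test, which has the child trace its parent, this sketch has the parent trace a forked child:

```cpp
#include <sys/ptrace.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <csignal>
#include <cstdio>

volatile long shared_word = 10;  // Same virtual address in parent and child after fork().

int main() {
  pid_t child = fork();
  if (child == 0) {
    ptrace(PTRACE_TRACEME, 0, nullptr, nullptr);
    raise(SIGSTOP);  // Stop until the parent has peeked and poked.
    printf("child now sees %ld\n", (long)shared_word);  // Prints 20 after the poke.
    return 0;
  }

  int status;
  waitpid(child, &status, 0);  // Child is in signal-delivery-stop.

  errno = 0;
  long peeked = ptrace(PTRACE_PEEKDATA, child, (void*)&shared_word, nullptr);
  if (errno != 0) { perror("PTRACE_PEEKDATA"); return 1; }
  printf("peeked %ld\n", peeked);  // 10

  ptrace(PTRACE_POKEDATA, child, (void*)&shared_word, (void*)20L);
  ptrace(PTRACE_DETACH, child, nullptr, nullptr);  // Suppress the SIGSTOP and resume.
  waitpid(child, &status, 0);
  return 0;
}
```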
+ const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + ASSERT_NO_ERRNO(Truncate(file.path(), kPageSize)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + const auto file_mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap( + nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - volatile long word = kBeforePokeDataValue; + constexpr long kBeforePokeDataAnonValue = 10; + constexpr long kAfterPokeDataAnonValue = 20; + constexpr long kBeforePokeDataFileValue = 0; // implicit, due to truncate() + constexpr long kAfterPokeDataFileValue = 30; + + volatile long anon_word = kBeforePokeDataAnonValue; + auto* file_word_ptr = static_cast<volatile long*>(file_mapping.ptr()); pid_t const child_pid = fork(); if (child_pid == 0) { @@ -134,12 +148,22 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { MaybeSave(); TEST_CHECK(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - // Replace the value of word in the parent process with kAfterPokeDataValue. - long const parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &word, 0); + // Replace the value of anon_word in the parent process with + // kAfterPokeDataAnonValue. + long parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &anon_word, 0); + MaybeSave(); + TEST_CHECK(parent_word == kBeforePokeDataAnonValue); + TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, &anon_word, + kAfterPokeDataAnonValue) == 0); + MaybeSave(); + + // Replace the value pointed to by file_word_ptr in the mapped file with + // kAfterPokeDataFileValue, via the parent process' mapping. + parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, file_word_ptr, 0); MaybeSave(); - TEST_CHECK(parent_word == kBeforePokeDataValue); - TEST_PCHECK( - ptrace(PTRACE_POKEDATA, parent_pid, &word, kAfterPokeDataValue) == 0); + TEST_CHECK(parent_word == kBeforePokeDataFileValue); + TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, file_word_ptr, + kAfterPokeDataFileValue) == 0); MaybeSave(); // Detach from the parent and suppress the SIGSTOP. If the SIGSTOP is not @@ -160,7 +184,8 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { << " status " << status; // Check that the child's PTRACE_POKEDATA was effective. - EXPECT_EQ(kAfterPokeDataValue, word); + EXPECT_EQ(kAfterPokeDataAnonValue, anon_word); + EXPECT_EQ(kAfterPokeDataFileValue, *file_word_ptr); } TEST(PtraceTest, GetSigMask) { diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 1b9dbc584..bd779da92 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -438,6 +438,19 @@ TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +// Test getsockopt for SO_ACCEPTCONN. +TEST_F(RawSocketICMPTest, GetSocketAcceptConn) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // We're going to receive both the echo request and reply, but the order is // indeterminate. 
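Several hunks in this change add the same assertion: getsockopt(SO_ACCEPTCONN) reports 0 for any socket that has not had listen() called on it, whether it is a raw packet, ICMP, or UDP socket. The general pattern outside the test harness (accept_conn is an illustrative helper name):

```cpp
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstdio>

static int accept_conn(int fd) {
  int val = -1;
  socklen_t len = sizeof(val);
  getsockopt(fd, SOL_SOCKET, SO_ACCEPTCONN, &val, &len);
  return val;  // 1 only for sockets that listen() has been called on.
}

int main() {
  int udp = socket(AF_INET, SOCK_DGRAM, 0);
  int tcp = socket(AF_INET, SOCK_STREAM, 0);
  printf("udp: %d, tcp before listen: %d\n", accept_conn(udp), accept_conn(tcp));  // 0, 0
  listen(tcp, 1);  // listen() on an unbound socket auto-binds to an ephemeral port.
  printf("tcp after listen: %d\n", accept_conn(tcp));  // 1
  close(udp);
  close(tcp);
  return 0;
}
```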
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index e9b131ca9..890f4a246 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <signal.h> #include <sys/ipc.h> #include <sys/sem.h> #include <sys/types.h> @@ -486,6 +487,292 @@ TEST(SemaphoreTest, SemIpcSet) { ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallFailsWithErrno(EACCES)); } +TEST(SemaphoreTest, SemCtlIpcStat) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + const uid_t kUid = getuid(); + const gid_t kGid = getgid(); + time_t start_time = time(nullptr); + + AutoSem sem(semget(IPC_PRIVATE, 10, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + struct semid_ds ds; + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds()); + + EXPECT_EQ(ds.sem_perm.__key, IPC_PRIVATE); + EXPECT_EQ(ds.sem_perm.uid, kUid); + EXPECT_EQ(ds.sem_perm.gid, kGid); + EXPECT_EQ(ds.sem_perm.cuid, kUid); + EXPECT_EQ(ds.sem_perm.cgid, kGid); + EXPECT_EQ(ds.sem_perm.mode, 0600); + // Last semop time is not set on creation. + EXPECT_EQ(ds.sem_otime, 0); + EXPECT_GE(ds.sem_ctime, start_time); + EXPECT_EQ(ds.sem_nsems, 10); + + // The timestamps only have a resolution of seconds; slow down so we actually + // see the timestamps change. + absl::SleepFor(absl::Seconds(1)); + + // Set semid_ds structure of the set. + auto last_ctime = ds.sem_ctime; + start_time = time(nullptr); + struct semid_ds semid_to_set = {}; + semid_to_set.sem_perm.uid = kUid; + semid_to_set.sem_perm.gid = kGid; + semid_to_set.sem_perm.mode = 0666; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds()); + struct sembuf buf = {}; + buf.sem_op = 1; + ASSERT_THAT(semop(sem.get(), &buf, 1), SyscallSucceeds()); + + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), SyscallSucceeds()); + EXPECT_EQ(ds.sem_perm.mode, 0666); + EXPECT_GE(ds.sem_otime, start_time); + EXPECT_GT(ds.sem_ctime, last_ctime); + + // An invalid semid fails the syscall with errno EINVAL. + EXPECT_THAT(semctl(sem.get() + 1, 0, IPC_STAT, &ds), + SyscallFailsWithErrno(EINVAL)); + + // Make semaphore not readable and check the signal fails. + semid_to_set.sem_perm.mode = 0200; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &semid_to_set), SyscallSucceeds()); + EXPECT_THAT(semctl(sem.get(), 0, IPC_STAT, &ds), + SyscallFailsWithErrno(EACCES)); +} + +// Calls semctl(semid, 0, cmd) until the returned value is >= target, an +// internal timeout expires, or semctl returns an error. +PosixErrorOr<int> WaitSemctl(int semid, int target, int cmd) { + constexpr absl::Duration timeout = absl::Seconds(10); + const auto deadline = absl::Now() + timeout; + int semcnt = 0; + while (absl::Now() < deadline) { + semcnt = semctl(semid, 0, cmd); + if (semcnt < 0) { + return PosixError(errno, "semctl(GETZCNT) failed"); + } + if (semcnt >= target) { + break; + } + absl::SleepFor(absl::Milliseconds(10)); + } + return semcnt; +} + +TEST(SemaphoreTest, SemopGetzcnt) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + // Create a write only semaphore set. + AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + // No read permission to retrieve semzcnt. 
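The new SemCtlIpcStat test above checks the semid_ds fields reported by IPC_STAT, in particular that sem_otime stays zero until the first semop() and that IPC_SET updates sem_ctime. A trimmed sketch of the same calls; as in the test, the semid_ds pointer is passed directly to the variadic semctl():

```cpp
#include <sys/ipc.h>
#include <sys/sem.h>
#include <cstdio>

int main() {
  int id = semget(IPC_PRIVATE, 10, IPC_CREAT | 0600);
  if (id < 0) { perror("semget"); return 1; }

  struct semid_ds ds = {};
  if (semctl(id, 0, IPC_STAT, &ds) < 0) { perror("IPC_STAT"); return 1; }
  // sem_otime is 0 until the first semop(); sem_ctime is set at creation
  // and bumped by IPC_SET.
  printf("nsems=%lu mode=%o otime=%ld ctime=%ld\n",
         (unsigned long)ds.sem_nsems, (unsigned)(ds.sem_perm.mode & 0777),
         (long)ds.sem_otime, (long)ds.sem_ctime);

  ds.sem_perm.mode = 0666;
  if (semctl(id, 0, IPC_SET, &ds) < 0) { perror("IPC_SET"); return 1; }

  semctl(id, 0, IPC_RMID);  // Clean up the set.
  return 0;
}
```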
+ EXPECT_THAT(semctl(sem.get(), 0, GETZCNT), SyscallFailsWithErrno(EACCES)); + + // Remove the calling thread's read permission. + struct semid_ds ds = {}; + ds.sem_perm.uid = getuid(); + ds.sem_perm.gid = getgid(); + ds.sem_perm.mode = 0600; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds()); + + std::vector<pid_t> children; + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); + + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + constexpr size_t kLoops = 10; + for (auto i = 0; i < kLoops; i++) { + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); + _exit(0); + } + children.push_back(child_pid); + } + + EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETZCNT), + IsPosixErrorOkAndHolds(kLoops)); + // Set semval to 0, which wakes up children that sleep on the semop. + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 0), SyscallSucceeds()); + for (const auto& child_pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0); +} + +TEST(SemaphoreTest, SemopGetzcntOnSetRemoval) { + auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT); + ASSERT_THAT(semid, SyscallSucceeds()); + ASSERT_THAT(semctl(semid, 0, SETVAL, 1), SyscallSucceeds()); + ASSERT_EQ(semctl(semid, 0, GETZCNT), 0); + + auto child_pid = fork(); + if (child_pid == 0) { + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + + // Ensure that wait will only unblock when the semaphore is removed. On + // EINTR retry it may race with deletion and return EINVAL. + TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(semid, 1, GETZCNT), IsPosixErrorOkAndHolds(1)); + // Remove the semaphore set, which fails the sleep semop. + ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds()); + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_THAT(semctl(semid, 0, GETZCNT), SyscallFailsWithErrno(EINVAL)); +} + +TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); + ASSERT_EQ(semctl(sem.get(), 0, GETZCNT), 0); + + // Saving will cause semop() to be spuriously interrupted. + DisableSave ds; + + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR); + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + + TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(sem.get(), 1, GETZCNT), IsPosixErrorOkAndHolds(1)); + // Send a signal to the child, which fails the sleep semop. + ASSERT_EQ(kill(child_pid, SIGHUP), 0); + + ds.reset(); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0); +} + +TEST(SemaphoreTest, SemopGetncnt) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + // Create a write only semaphore set. 
+ AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + // No read permission to retrieve semzcnt. + EXPECT_THAT(semctl(sem.get(), 0, GETNCNT), SyscallFailsWithErrno(EACCES)); + + // Remove the calling thread's read permission. + struct semid_ds ds = {}; + ds.sem_perm.uid = getuid(); + ds.sem_perm.gid = getgid(); + ds.sem_perm.mode = 0600; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds()); + + std::vector<pid_t> children; + + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + constexpr size_t kLoops = 10; + for (auto i = 0; i < kLoops; i++) { + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); + _exit(0); + } + children.push_back(child_pid); + } + EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETNCNT), + IsPosixErrorOkAndHolds(kLoops)); + // Set semval to 1, which wakes up children that sleep on the semop. + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, kLoops), SyscallSucceeds()); + for (const auto& child_pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0); +} + +TEST(SemaphoreTest, SemopGetncntOnSetRemoval) { + auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT); + ASSERT_THAT(semid, SyscallSucceeds()); + ASSERT_EQ(semctl(semid, 0, GETNCNT), 0); + + auto child_pid = fork(); + if (child_pid == 0) { + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + + // Ensure that wait will only unblock when the semaphore is removed. On + // EINTR retry it may race with deletion and return EINVAL + TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(semid, 1, GETNCNT), IsPosixErrorOkAndHolds(1)); + // Remove the semaphore set, which fails the sleep semop. + ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds()); + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_THAT(semctl(semid, 0, GETNCNT), SyscallFailsWithErrno(EINVAL)); +} + +TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + ASSERT_EQ(semctl(sem.get(), 0, GETNCNT), 0); + + // Saving will cause semop() to be spuriously interrupted. + DisableSave ds; + + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR); + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + + TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR); + _exit(0); + } + EXPECT_THAT(WaitSemctl(sem.get(), 1, GETNCNT), IsPosixErrorOkAndHolds(1)); + // Send a signal to the child, which fails the sleep semop. 
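The GETZCNT/GETNCNT tests poll through the WaitSemctl helper because the counters only change once the kernel has actually queued the sleeping child. The counters themselves are simple: GETNCNT counts tasks waiting for the semaphore value to increase, GETZCNT counts tasks waiting for it to reach zero. A compact sketch of GETNCNT moving from 0 to 1 and back:

```cpp
#include <sys/ipc.h>
#include <sys/sem.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cstdio>

int main() {
  int id = semget(IPC_PRIVATE, 1, IPC_CREAT | 0600);
  if (id < 0 || semctl(id, 0, SETVAL, 0) < 0) { perror("setup"); return 1; }

  pid_t pid = fork();
  if (pid == 0) {                 // Child blocks until the value becomes >= 1.
    struct sembuf op = {0, -1, 0};
    semop(id, &op, 1);
    _exit(0);
  }

  // Poll until the child is actually queued on the semaphore.
  while (semctl(id, 0, GETNCNT) < 1) usleep(10 * 1000);

  semctl(id, 0, SETVAL, 1);       // Wakes the child; it decrements the value to 0.
  waitpid(pid, nullptr, 0);
  printf("ncnt=%d zcnt=%d\n", semctl(id, 0, GETNCNT), semctl(id, 0, GETZCNT));  // 0 0
  semctl(id, 0, IPC_RMID);
  return 0;
}
```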
+ ASSERT_EQ(kill(child_pid, SIGHUP), 0); + + ds.reset(); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index a8bfb01f1..cf0977118 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -25,9 +25,11 @@ #include "absl/time/time.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/timer_util.h" namespace gvisor { namespace testing { @@ -629,6 +631,57 @@ TEST(SendFileTest, SendFileToPipe) { SyscallSucceedsWithValue(kDataSize)); } +static volatile int signaled = 0; +void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } + +TEST(SendFileTest, ToEventFDDoesNotSpin_NoRandomSave) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + + // Write the maximum value of an eventfd to a file. + const uint64_t kMaxEventfdValue = 0xfffffffffffffffe; + const auto tempfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const auto tempfd = ASSERT_NO_ERRNO_AND_VALUE(Open(tempfile.path(), O_RDWR)); + ASSERT_THAT( + pwrite(tempfd.get(), &kMaxEventfdValue, sizeof(kMaxEventfdValue), 0), + SyscallSucceedsWithValue(sizeof(kMaxEventfdValue))); + + // Set the eventfd's value to 1. + const uint64_t kOne = 1; + ASSERT_THAT(write(efd.get(), &kOne, sizeof(kOne)), + SyscallSucceedsWithValue(sizeof(kOne))); + + // Set up signal handler. + struct sigaction sa = {}; + sa.sa_sigaction = SigUsr1Handler; + sa.sa_flags = SA_SIGINFO; + const auto cleanup_sigact = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); + + // Send SIGUSR1 to this thread in 1 second. + struct sigevent sev = {}; + sev.sigev_notify = SIGEV_THREAD_ID; + sev.sigev_signo = SIGUSR1; + sev.sigev_notify_thread_id = gettid(); + auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); + struct itimerspec its = {}; + its.it_value = absl::ToTimespec(absl::Seconds(1)); + DisableSave ds; // Asserting an EINTR. + ASSERT_NO_ERRNO(timer.Set(0, its)); + + // Sendfile from tempfd to the eventfd. Since the eventfd is not already at + // its maximum value, the eventfd is "ready for writing"; however, since the + // eventfd's existing value plus the new value would exceed the maximum, the + // write should internally fail with EWOULDBLOCK. In this case, sendfile() + // should block instead of spinning, and eventually be interrupted by our + // timer. See b/172075629. + EXPECT_THAT( + sendfile(efd.get(), tempfd.get(), nullptr, sizeof(kMaxEventfdValue)), + SyscallFailsWithErrno(EINTR)); + + // Signal should have been handled. + EXPECT_EQ(signaled, 1); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 11fcec443..e19a83413 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -350,6 +350,10 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownListen) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + // TODO(b/157236388): Remove Disable save after bug is fixed. 
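The sendfile test above targets b/172075629: an eventfd stays "writable" until its counter reaches the maximum, but a single write that would push the counter past 0xfffffffffffffffe must block (or fail with EWOULDBLOCK on a non-blocking descriptor) rather than let the caller spin. The eventfd side of that contract in isolation:

```cpp
#include <sys/eventfd.h>
#include <unistd.h>
#include <cerrno>
#include <cstdint>
#include <cstdio>

int main() {
  int efd = eventfd(0, EFD_NONBLOCK);
  if (efd < 0) { perror("eventfd"); return 1; }

  uint64_t one = 1;
  uint64_t huge = 0xfffffffffffffffeULL;  // Maximum value an eventfd can hold.
  if (write(efd, &one, sizeof(one)) != sizeof(one)) return 1;  // Counter is now 1.

  // 1 + huge would exceed the maximum, so the addition cannot complete.
  if (write(efd, &huge, sizeof(huge)) < 0 && errno == EAGAIN)
    printf("overflowing add would block (EAGAIN on a non-blocking eventfd)\n");
  close(efd);
  return 0;
}
```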
S/R test can + // fail because the last socket may not be delivered to the accept queue + // by the time connect returns. + DisableSave ds; for (int i = 0; i < kBacklog; i++) { auto client = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); @@ -554,7 +558,11 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { }); } -TEST_P(SocketInetLoopbackTest, TCPbacklog) { +// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/ +// random save as established connections which can't be delivered to the accept +// queue because the queue is full are not correctly delivered after restore +// causing the last accept to timeout on the restore. +TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -567,7 +575,8 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog) { ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), 2), SyscallSucceeds()); + constexpr int kBacklogSize = 2; + ASSERT_THAT(listen(listen_fd.get(), kBacklogSize), SyscallSucceeds()); // Get the port bound by the listening socket. socklen_t addrlen = listener.addr_len; @@ -931,7 +940,7 @@ void setupTimeWaitClose(const TestAddress* listener, } // shutdown to trigger TIME_WAIT. - ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds()); + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); { const int kTimeout = 10000; struct pollfd pfd = { @@ -941,7 +950,8 @@ void setupTimeWaitClose(const TestAddress* listener, ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); ASSERT_EQ(pfd.revents, POLLIN); } - ScopedThread t([&]() { + ASSERT_THAT(shutdown(passive_closefd.get(), SHUT_WR), SyscallSucceeds()); + { constexpr int kTimeout = 10000; constexpr int16_t want_events = POLLHUP; struct pollfd pfd = { @@ -949,11 +959,8 @@ void setupTimeWaitClose(const TestAddress* listener, .events = want_events, }; ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - }); + } - passive_closefd.reset(); - t.Join(); - active_closefd.reset(); // This sleep is needed to reduce flake to ensure that the passive-close // ensures the state transitions to CLOSE from LAST_ACK. absl::SleepFor(absl::Seconds(1)); @@ -1143,6 +1150,9 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + // TODO(b/157236388): Reenable Cooperative S/R once bug is fixed. + DisableSave ds; ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), connector.addr_len), diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index 3f2c0fdf2..f69f8f99f 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -472,5 +472,19 @@ TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); } +// Test getsockopt for SO_ACCEPTCONN on udp socket. 
+TEST_P(UDPSocketPairTest, GetSocketAcceptConn) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 8f7ccc868..029f1e872 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -454,23 +454,15 @@ TEST_P(IPUnboundSocketTest, SetReuseAddr) { INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, - ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( + ::testing::ValuesIn(VecCat<SocketKind>( ApplyVec<SocketKind>(IPv4UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector<int>{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv6UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector<int>{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv4TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv6TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})))))); + std::vector{0, SOCK_NONBLOCK})))); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound_netlink.cc b/test/syscalls/linux/socket_ip_unbound_netlink.cc new file mode 100644 index 000000000..6036bfcaf --- /dev/null +++ b/test/syscalls/linux/socket_ip_unbound_netlink.cc @@ -0,0 +1,104 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of IP sockets. +using IPv6UnboundSocketTest = SimpleSocketTest; + +TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved + // across save/restore. + DisableSave ds; + + // Delete the loopback address from the loopback interface. 
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/128, &in6addr_loopback, + sizeof(in6addr_loopback))); + Cleanup defer_addr_removal = + Cleanup([loopback_link = std::move(loopback_link)] { + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/128, &in6addr_loopback, + sizeof(in6addr_loopback))); + }); + + TestAddress addr = V6Loopback(); + reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = 65535; + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv6UnboundSocketTest, + ::testing::ValuesIn(std::vector<SocketKind>{ + IPv6UDPUnboundSocket(0), + IPv6TCPUnboundSocket(0)})); + +using IPv4UnboundSocketTest = SimpleSocketTest; + +TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved + // across save/restore. + DisableSave ds; + + // Delete the loopback address from the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr laddr; + laddr.s_addr = htonl(INADDR_LOOPBACK); + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/8, &laddr, sizeof(laddr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(laddr)] { + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/8, &addr, sizeof(addr))); + }); + TestAddress addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = 65535; + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv4UnboundSocketTest, + ::testing::ValuesIn(std::vector<SocketKind>{ + IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0)})); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index a72c76c97..b3f54e7f6 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -28,6 +28,7 @@ #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/test_util.h" namespace gvisor { @@ -75,7 +76,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -209,7 +210,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { // Check that we received the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -265,7 +266,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -321,7 +322,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -377,7 +378,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -437,7 +438,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -497,7 +498,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -553,7 +554,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -609,7 +610,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -669,7 +670,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that we did not receive the multicast packet. 
char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -727,7 +728,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -785,7 +786,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -845,7 +846,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); @@ -919,7 +920,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -977,7 +978,7 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -1330,8 +1331,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1409,8 +1410,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { // Check that we received the multicast packet on both sockets. 
for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1432,8 +1433,8 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { char recv_buf[sizeof(send_buf)] = {}; for (auto& sockets : socket_pairs) { - ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, - sizeof(recv_buf), 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(sockets->second_fd(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -1486,7 +1487,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1530,7 +1531,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), PosixErrorIs(EAGAIN, ::testing::_)); } @@ -1580,7 +1581,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1627,7 +1628,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1678,7 +1679,7 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; ASSERT_THAT( - RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1737,7 +1738,7 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // of the other sockets to have received it, but we will check that later. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + RecvTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(send_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1745,9 +1746,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // Verify that no other messages were received. 
for (auto& socket : sockets) { char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - PosixErrorIs(EAGAIN, ::testing::_)); + EXPECT_THAT( + RecvTimeout(socket->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -2108,6 +2109,9 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { constexpr int kMessageSize = 10; + // Saving during each iteration of the following loop is too expensive. + DisableSave ds; + for (int i = 0; i < 100; ++i) { // Send a new message to the REUSEADDR/REUSEPORT group. We use a new socket // each time so that a new ephemerial port will be used each time. This @@ -2120,16 +2124,18 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(sizeof(send_buf))); } + ds.reset(); + // Check that both receivers got messages. This checks that we are using load // balancing (REUSEPORT) instead of the most recently bound socket // (REUSEADDR). char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - IsPosixErrorOkAndHolds(kMessageSize)); - EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), - 1 /*timeout*/), - IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT( + RecvTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT( + RecvTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); } // Test that socket will receive packet info control message. @@ -2193,8 +2199,8 @@ TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); received_msg.msg_control = received_cmsg_buf; - ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); + ASSERT_THAT(RecvMsgTimeout(receiver->get(), &received_msg, 1 /*timeout*/), + IsPosixErrorOkAndHolds(kDataLength)); cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); ASSERT_NE(cmsg, nullptr); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc index 49a0f06d9..875016812 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -40,17 +40,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { /*prefixlen=*/24, &addr, sizeof(addr))); Cleanup defer_addr_removal = Cleanup( [loopback_link = std::move(loopback_link), addr = std::move(addr)] { - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses - // via netlink is supported. 
- EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); - } else { - EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, - sizeof(addr))); - } + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); }); auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -124,17 +116,9 @@ TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { 24 /* prefixlen */, &addr, sizeof(addr))); Cleanup defer_addr_removal = Cleanup( [loopback_link = std::move(loopback_link), addr = std::move(addr)] { - if (IsRunningOnGvisor()) { - // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses - // via netlink is supported. - EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); - } else { - EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, - /*prefixlen=*/24, &addr, - sizeof(addr))); - } + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); }); TestAddress broadcast_address("SubnetBroadcastAddress"); diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index 241ddad74..ee3c08770 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -511,53 +511,42 @@ TEST(NetlinkRouteTest, LookupAll) { ASSERT_GT(count, 0); } -TEST(NetlinkRouteTest, AddAddr) { +TEST(NetlinkRouteTest, AddAndRemoveAddr) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + // Don't do cooperative save/restore because netstack state is not restored. + // TODO(gvisor.dev/issue/4595): enable cooperative save tests. + const DisableSave ds; Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); - FileDescriptor fd = - ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifa; - struct rtattr rtattr; - struct in_addr addr; - char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; - }; - - struct request req = {}; - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_seq = kSeq; - req.ifa.ifa_family = AF_INET; - req.ifa.ifa_prefixlen = 24; - req.ifa.ifa_flags = 0; - req.ifa.ifa_scope = 0; - req.ifa.ifa_index = loopback_link.index; - req.rtattr.rta_type = IFA_LOCAL; - req.rtattr.rta_len = RTA_LENGTH(sizeof(req.addr)); - inet_pton(AF_INET, "10.0.0.1", &req.addr); - req.hdr.nlmsg_len = - NLMSG_LENGTH(sizeof(req.ifa)) + NLMSG_ALIGN(req.rtattr.rta_len); + struct in_addr addr; + ASSERT_EQ(inet_pton(AF_INET, "10.0.0.1", &addr), 1); // Create should succeed, as no such address in kernel. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_ACK; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); + + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + // First delete should succeed, as address exists. + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + + // Second delete should fail, as address no longer exists. 
+ EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EADDRNOTAVAIL, ::testing::_)); + }); // Replace an existing address should succeed. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_NO_ERRNO( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len)); + ASSERT_NO_ERRNO(LinkReplaceLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); // Create exclusive should fail, as we created the address above. - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL | NLM_F_ACK; - req.hdr.nlmsg_seq++; - EXPECT_THAT( - NetlinkRequestAckOrError(fd, req.hdr.nlmsg_seq, &req, req.hdr.nlmsg_len), - PosixErrorIs(EEXIST, ::testing::_)); + EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EEXIST, ::testing::_)); } // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index 7a0bad4cb..46f749c7c 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -29,6 +29,8 @@ constexpr uint32_t kSeq = 12345; // Types of address modifications that may be performed on an interface. enum class LinkAddrModification { kAdd, + kAddExclusive, + kReplace, kDelete, }; @@ -40,6 +42,14 @@ PosixError PopulateNlmsghdr(LinkAddrModification modification, hdr->nlmsg_type = RTM_NEWADDR; hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; return NoError(); + case LinkAddrModification::kAddExclusive: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_EXCL | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kReplace: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_REPLACE | NLM_F_ACK; + return NoError(); case LinkAddrModification::kDelete: hdr->nlmsg_type = RTM_DELADDR; hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; @@ -144,6 +154,18 @@ PosixError LinkAddLocalAddr(int index, int family, int prefixlen, LinkAddrModification::kAdd); } +PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAddExclusive); +} + +PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kReplace); +} + PosixError LinkDelLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index e5badca70..eaa91ad79 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -39,10 +39,19 @@ PosixErrorOr<std::vector<Link>> DumpLinks(); // Returns the loopback link on the system. ENOENT if not found. PosixErrorOr<Link> LoopbackLink(); -// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface. +// LinkAddLocalAddr adds a new IFA_LOCAL address to the interface. PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkAddExclusiveLocalAddr adds a new IFA_LOCAL address with NLM_F_EXCL flag +// to the interface. 
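As an aside before the declarations that follow, here is a minimal usage sketch of the new netlink address helpers, mirroring the rewritten AddAndRemoveAddr test above. The test name, the 10.0.0.2 address, and the include list are illustrative assumptions, not part of the change:

#include <arpa/inet.h>
#include <netinet/in.h>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_netlink_route_util.h"
#include "test/util/capability_util.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"

namespace gvisor {
namespace testing {

TEST(NetlinkAddrHelpersExample, AddReplaceExclusiveDelete) {
  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
  Link lo = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());

  struct in_addr addr;
  ASSERT_EQ(inet_pton(AF_INET, "10.0.0.2", &addr), 1);

  // kAdd: RTM_NEWADDR with NLM_F_REQUEST | NLM_F_ACK.
  ASSERT_NO_ERRNO(LinkAddLocalAddr(lo.index, AF_INET, /*prefixlen=*/24, &addr,
                                   sizeof(addr)));
  // kReplace: RTM_NEWADDR with NLM_F_REPLACE; succeeds for an existing
  // address.
  EXPECT_NO_ERRNO(LinkReplaceLocalAddr(lo.index, AF_INET, /*prefixlen=*/24,
                                       &addr, sizeof(addr)));
  // kAddExclusive: RTM_NEWADDR with NLM_F_EXCL; fails because the address is
  // already present.
  EXPECT_THAT(LinkAddExclusiveLocalAddr(lo.index, AF_INET, /*prefixlen=*/24,
                                        &addr, sizeof(addr)),
              PosixErrorIs(EEXIST, ::testing::_));
  // kDelete: RTM_DELADDR removes the address again.
  EXPECT_NO_ERRNO(LinkDelLocalAddr(lo.index, AF_INET, /*prefixlen=*/24, &addr,
                                   sizeof(addr)));
}

}  // namespace testing
}  // namespace gvisor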
+PosixError LinkAddExclusiveLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + +// LinkReplaceLocalAddr replaces an IFA_LOCAL address on the interface. +PosixError LinkReplaceLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. PosixError LinkDelLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index e11792309..a760581b5 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -753,8 +753,7 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { return ret; } -PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, - int timeout) { +PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout) { fd_set rfd; struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; FD_ZERO(&rfd); @@ -767,6 +766,19 @@ PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, return ret; } +PosixErrorOr<int> RecvMsgTimeout(int sock, struct msghdr* msg, int timeout) { + fd_set rfd; + struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; + FD_ZERO(&rfd); + FD_SET(sock, &rfd); + + int ret; + RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to)); + RETURN_ERROR_IF_SYSCALL_FAIL( + ret = RetryEINTR(recvmsg)(sock, msg, MSG_DONTWAIT)); + return ret; +} + void RecvNoData(int sock) { char data = 0; struct iovec iov; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 468bc96e0..5e205339f 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -467,9 +467,12 @@ PosixError FreeAvailablePort(int port); // SendMsg converts a buffer to an iovec and adds it to msg before sending it. PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size); -// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock. -PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, - int timeout); +// RecvTimeout calls select on sock with timeout and then calls recv on sock. +PosixErrorOr<int> RecvTimeout(int sock, char buf[], int buf_size, int timeout); + +// RecvMsgTimeout calls select on sock with timeout and then calls recvmsg on +// sock. +PosixErrorOr<int> RecvMsgTimeout(int sock, msghdr* msg, int timeout); // RecvNoData checks that no data is receivable on sock. 
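Before the RecvNoData declaration below, an illustrative side-by-side of the two timeout helpers declared above. This is a fragment from a hypothetical test body; sock, the payload size, and the expected return values are assumptions, and the headers are the ones these tests already include:

  // RecvTimeout: select() on the socket, then recv() into a flat buffer.
  char buf[16] = {};
  ASSERT_THAT(RecvTimeout(sock->get(), buf, sizeof(buf), 1 /*timeout*/),
              IsPosixErrorOkAndHolds(sizeof(buf)));

  // RecvMsgTimeout: select() on the socket, then recvmsg(); useful when the
  // test also needs ancillary data, as in SetAndReceiveIPPKTINFO above.
  struct iovec iov = {};
  iov.iov_base = buf;
  iov.iov_len = sizeof(buf);
  struct msghdr msg = {};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  ASSERT_THAT(RecvMsgTimeout(sock->get(), &msg, 1 /*timeout*/),
              IsPosixErrorOkAndHolds(sizeof(buf)));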
void RecvNoData(int sock); diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 1edcb15a7..ad9c4bf37 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -121,6 +121,19 @@ TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) { EXPECT_EQ(0, memcmp(&got_linger, &sl, length)); } +TEST_P(StreamUnixSocketPairTest, GetSocketAcceptConn) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index a1d2b9b11..c2369db54 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -26,9 +26,11 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/timer_util.h" namespace gvisor { namespace testing { @@ -772,6 +774,59 @@ TEST(SpliceTest, FromPipeToDevZero) { SyscallSucceedsWithValue(0)); } +static volatile int signaled = 0; +void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } + +TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin_NoRandomSave) { + // Writes to a pipe that are less than PIPE_BUF must be atomic. This test + // creates a pipe with only 128 bytes of capacity (< PIPE_BUF) and checks that + // splicing to the pipe does not spin. See b/170743336. + + // Create a file with one page of data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), absl::string_view(buf.data(), buf.size()), + TempPath::kDefaultFileMode)); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + // Create a pipe with size 4096, and fill all but 128 bytes of it. + int p[2]; + ASSERT_THAT(pipe(p), SyscallSucceeds()); + ASSERT_THAT(fcntl(p[1], F_SETPIPE_SZ, kPageSize), SyscallSucceeds()); + const int kWriteSize = kPageSize - 128; + std::vector<char> writeBuf(kWriteSize); + RandomizeBuffer(writeBuf.data(), writeBuf.size()); + ASSERT_THAT(write(p[1], writeBuf.data(), writeBuf.size()), + SyscallSucceedsWithValue(kWriteSize)); + + // Set up signal handler. + struct sigaction sa = {}; + sa.sa_sigaction = SigUsr1Handler; + sa.sa_flags = SA_SIGINFO; + const auto cleanup_sigact = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); + + // Send SIGUSR1 to this thread in 1 second. + struct sigevent sev = {}; + sev.sigev_notify = SIGEV_THREAD_ID; + sev.sigev_signo = SIGUSR1; + sev.sigev_notify_thread_id = gettid(); + auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); + struct itimerspec its = {}; + its.it_value = absl::ToTimespec(absl::Seconds(1)); + DisableSave ds; // Asserting an EINTR. + ASSERT_NO_ERRNO(timer.Set(0, its)); + + // Now splice the file to the pipe. This should block, but not spin, and + // should return EINTR because it is interrupted by the signal. 
+ EXPECT_THAT(splice(fd.get(), nullptr, p[1], nullptr, kPageSize, 0), + SyscallFailsWithErrno(EINTR)); + + // Alarm should have been handled. + EXPECT_EQ(signaled, 1); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 92260b1e1..6e7142a42 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -31,6 +31,7 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/save_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -328,7 +329,10 @@ TEST_F(StatTest, LeadingDoubleSlash) { ASSERT_THAT(lstat(double_slash_path.c_str(), &double_slash_st), SyscallSucceeds()); EXPECT_EQ(st.st_dev, double_slash_st.st_dev); - EXPECT_EQ(st.st_ino, double_slash_st.st_ino); + // Inode numbers for gofer-accessed files may change across save/restore. + if (!IsRunningWithSaveRestore()) { + EXPECT_EQ(st.st_ino, double_slash_st.st_ino); + } } // Test that a rename doesn't change the underlying file. @@ -346,8 +350,14 @@ TEST_F(StatTest, StatDoesntChangeAfterRename) { EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); EXPECT_EQ(st_old.st_dev, st_new.st_dev); + // Inode numbers for gofer-accessed files on which no reference is held may + // change across save/restore because the information that the gofer client + // uses to track file identity (9P QID path) is inconsistent between gofer + // processes, which are restarted across save/restore. + // // Overlay filesystems may synthesize directory inode numbers on the fly. - if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { + if (!IsRunningWithSaveRestore() && + !ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { EXPECT_EQ(st_old.st_ino, st_new.st_ino); } EXPECT_EQ(st_old.st_mode, st_new.st_mode); @@ -541,6 +551,26 @@ TEST_F(StatTest, LstatELOOPPath) { ASSERT_THAT(lstat(path.c_str(), &s), SyscallFailsWithErrno(ELOOP)); } +TEST(SimpleStatTest, DifferentFilesHaveDifferentDeviceInodeNumberPairs) { + // TODO(gvisor.dev/issue/1624): This test case fails in VFS1 save/restore + // tests because VFS1 gofer inode number assignment restarts after + // save/restore, such that the inodes for file1 and file2 (which are + // unreferenced and therefore not retained in sentry checkpoints before the + // calls to lstat()) are assigned the same inode number. + SKIP_IF(IsRunningWithVFS1() && IsRunningWithSaveRestore()); + + TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + TempPath file2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + + MaybeSave(); + struct stat st1 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file1.path())); + MaybeSave(); + struct stat st2 = ASSERT_NO_ERRNO_AND_VALUE(Lstat(file2.path())); + EXPECT_FALSE(st1.st_dev == st2.st_dev && st1.st_ino == st2.st_ino) + << "both files have device number " << st1.st_dev << " and inode number " + << st1.st_ino; +} + // Ensure that inode allocation for anonymous devices work correctly across // save/restore. In particular, inode numbers should be unique across S/R. 
TEST(SimpleStatTest, AnonDeviceAllocatesUniqueInodesAcrossSaveRestore) { diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 9f522f833..ebd873068 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -1725,6 +1725,63 @@ TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) { ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout)); } +// Tests that SO_ACCEPTCONN returns non zero value for listening sockets. +TEST_P(TcpSocketTest, GetSocketAcceptConnListener) { + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(listener_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 1); +} + +// Tests that SO_ACCEPTCONN returns zero value for not listening sockets. +TEST_P(TcpSocketTest, GetSocketAcceptConnNonListener) { + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); + + ASSERT_THAT(getsockopt(t_, SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + +TEST_P(SimpleTcpSocketTest, GetSocketAcceptConnWithShutdown) { + // TODO(b/171345701): Fix the TCP state for listening socket on shutdown. + SKIP_IF(IsRunningOnGvisor()); + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + // Initialize address to the loopback one. + sockaddr_storage addr = + ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam())); + socklen_t addrlen = sizeof(addr); + + // Bind to some port then start listening. + ASSERT_THAT(bind(s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen), + SyscallSucceeds()); + + ASSERT_THAT(listen(s.get(), SOMAXCONN), SyscallSucceeds()); + + int got = -1; + socklen_t length = sizeof(got); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 1); + + EXPECT_THAT(shutdown(s.get(), SHUT_RD), SyscallSucceeds()); + ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ACCEPTCONN, &got, &length), + SyscallSucceeds()); + ASSERT_EQ(length, sizeof(got)); + EXPECT_EQ(got, 0); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index cac94d9e1..93a98adb1 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -322,11 +322,6 @@ TEST(IntervalTimerTest, PeriodicGroupDirectedSignal) { EXPECT_GE(counted_signals.load(), kCycles); } -// From Linux's include/uapi/asm-generic/siginfo.h. -#ifndef sigev_notify_thread_id -#define sigev_notify_thread_id _sigev_un._tid -#endif - TEST(IntervalTimerTest, PeriodicThreadDirectedSignal) { constexpr int kSigno = SIGUSR1; constexpr int kSigvalue = 42; diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 1a7673317..d65275fd3 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -679,6 +679,43 @@ TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { SyscallSucceedsWithValue(sizeof(buf))); } +TEST_P(UdpSocketTest, ConnectAndSendNoReceiver) { + ASSERT_NO_ERRNO(BindLoopback()); + // Close the socket to release the port so that we get an ICMP error. 
+ ASSERT_THAT(close(bind_.release()), SyscallSucceeds()); + + // Connect to loopback:bind_addr_ which should *hopefully* not be bound by an + // UDP socket. There is no easy way to ensure that the UDP port is not bound + // by another conncurrently running test. *This is potentially flaky*. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + constexpr int kTimeout = 1000; + // Poll to make sure we get the ICMP error back before issuing more writes. + struct pollfd pfd = {sock_.get(), POLLERR, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + + // Next write should fail with ECONNREFUSED due to the ICMP error generated in + // response to the previous write. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ECONNREFUSED)); + + // The next write should succeed again since the last write call would have + // retrieved and cleared the socket error. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), SyscallSucceeds()); + + // Poll to make sure we get the ICMP error back before issuing more writes. + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + + // Next write should fail with ECONNREFUSED due to the ICMP error generated in + // response to the previous write. + ASSERT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(ECONNREFUSED)); +} + TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. SKIP_IF(IsRunningWithHostinet()); @@ -838,7 +875,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { // Receive the data. It works because it was sent before the connect. char received[sizeof(buf)]; EXPECT_THAT( - RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), + RecvTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), IsPosixErrorOkAndHolds(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -928,9 +965,8 @@ TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { SyscallSucceedsWithValue(1)); // We should get the data even though read has been shutdown. - EXPECT_THAT( - RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), - IsPosixErrorOkAndHolds(2)); + EXPECT_THAT(RecvTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), + IsPosixErrorOkAndHolds(2)); // Because we read less than the entire packet length, since it's a packet // based socket any subsequent reads should return EWOULDBLOCK. 
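As an illustration of the datagram semantics that last comment describes (the buffer name and read size here are hypothetical, not a quote of the surrounding test):

  // A follow-up non-blocking read finds nothing: once part of a datagram has
  // been consumed, the unread remainder of that packet is discarded.
  char rest[1] = {};
  EXPECT_THAT(RetryEINTR(recv)(bind_.get(), rest, sizeof(rest), MSG_DONTWAIT),
              SyscallFailsWithErrno(EWOULDBLOCK));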
@@ -1698,8 +1734,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); } @@ -1714,8 +1750,8 @@ TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { SyscallSucceedsWithValue(buf.size())); std::vector<char> received(buf.size()); - ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + ASSERT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); } } @@ -1785,8 +1821,8 @@ TEST_P(UdpSocketTest, RecvBufLimits) { for (int i = 0; i < sent - 1; i++) { // Receive the data. std::vector<char> received(buf.size()); - EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), - 1 /*timeout*/), + EXPECT_THAT(RecvTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), IsPosixErrorOkAndHolds(received.size())); EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); } @@ -1851,6 +1887,22 @@ TEST_P(UdpSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } +TEST_P(UdpSocketTest, SendToZeroPort) { + char buf[8]; + struct sockaddr_storage addr = InetLoopbackAddr(); + + // Sending to an invalid port should fail. + SetPort(&addr, 0); + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); + + SetPort(&addr, 1234); + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceedsWithValue(sizeof(buf))); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/util/BUILD b/test/util/BUILD index 26c2b6a2f..1b028a477 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -155,6 +155,10 @@ cc_library( ], hdrs = ["save_util.h"], defines = select_system(), + deps = [ + ":logging", + "@com_google_absl//absl/types:optional", + ], ) cc_library( diff --git a/test/util/posix_error.cc b/test/util/posix_error.cc index cebf7e0ac..deed0c05b 100644 --- a/test/util/posix_error.cc +++ b/test/util/posix_error.cc @@ -87,7 +87,7 @@ bool PosixErrorIsMatcherCommonImpl::MatchAndExplain( return false; } - if (!message_matcher_.Matches(error.error_message())) { + if (!message_matcher_.Matches(error.message())) { return false; } diff --git a/test/util/posix_error.h b/test/util/posix_error.h index ad666bce0..b634a7f78 100644 --- a/test/util/posix_error.h +++ b/test/util/posix_error.h @@ -26,11 +26,6 @@ namespace gvisor { namespace testing { -class PosixErrorIsMatcherCommonImpl; - -template <typename T> -class PosixErrorOr; - class ABSL_MUST_USE_RESULT PosixError { public: PosixError() {} @@ -49,7 +44,8 @@ class ABSL_MUST_USE_RESULT PosixError { // PosixErrorOr. const PosixError& error() const { return *this; } - std::string error_message() const { return msg_; } + int errno_value() const { return errno_; } + std::string message() const { return msg_; } // ToString produces a full string representation of this posix error // including the printable representation of the errno and the error message. 
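An illustrative sketch of the PosixError surface after this rename; the two-argument constructor is assumed from existing usage in the test utilities and is not part of this hunk:

  PosixError err(EEXIST, "address already exists");     // assumed constructor
  EXPECT_EQ(err.errno_value(), EEXIST);                  // accessor is now public
  EXPECT_EQ(err.message(), "address already exists");    // renamed from error_message()
  EXPECT_THAT(err, PosixErrorIs(EEXIST, ::testing::_));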
@@ -61,14 +57,8 @@ class ABSL_MUST_USE_RESULT PosixError { void IgnoreError() const {} private: - int errno_value() const { return errno_; } int errno_ = 0; std::string msg_; - - friend class PosixErrorIsMatcherCommonImpl; - - template <typename T> - friend class PosixErrorOr; }; template <typename T> @@ -94,15 +84,12 @@ class ABSL_MUST_USE_RESULT PosixErrorOr { template <typename U> PosixErrorOr& operator=(PosixErrorOr<U> other); - // Return a reference to the error or NoError(). - PosixError error() const; - - // Returns this->error().error_message(); - std::string error_message() const; - // Returns true if this PosixErrorOr contains some T. bool ok() const; + // Return a copy of the contained PosixError or NoError(). + PosixError error() const; + // Returns a reference to our current value, or CHECK-fails if !this->ok(). const T& ValueOrDie() const&; T& ValueOrDie() &; @@ -115,7 +102,6 @@ class ABSL_MUST_USE_RESULT PosixErrorOr { void IgnoreError() const {} private: - int errno_value() const; absl::variant<T, PosixError> value_; friend class PosixErrorIsMatcherCommonImpl; @@ -171,16 +157,6 @@ PosixError PosixErrorOr<T>::error() const { } template <typename T> -int PosixErrorOr<T>::errno_value() const { - return error().errno_value(); -} - -template <typename T> -std::string PosixErrorOr<T>::error_message() const { - return error().error_message(); -} - -template <typename T> bool PosixErrorOr<T>::ok() const { return absl::holds_alternative<T>(value_); } diff --git a/test/util/save_util.cc b/test/util/save_util.cc index 384d626f0..59d47e06e 100644 --- a/test/util/save_util.cc +++ b/test/util/save_util.cc @@ -21,35 +21,47 @@ #include <atomic> #include <cerrno> -#define GVISOR_COOPERATIVE_SAVE_TEST "GVISOR_COOPERATIVE_SAVE_TEST" +#include "absl/types/optional.h" namespace gvisor { namespace testing { namespace { -enum class CooperativeSaveMode { - kUnknown = 0, // cooperative_save_mode is statically-initialized to 0 - kAvailable, - kNotAvailable, -}; - -std::atomic<CooperativeSaveMode> cooperative_save_mode; - -bool CooperativeSaveEnabled() { - auto mode = cooperative_save_mode.load(); - if (mode == CooperativeSaveMode::kUnknown) { - mode = (getenv(GVISOR_COOPERATIVE_SAVE_TEST) != nullptr) - ? 
CooperativeSaveMode::kAvailable - : CooperativeSaveMode::kNotAvailable; - cooperative_save_mode.store(mode); +std::atomic<absl::optional<bool>> cooperative_save_present; +std::atomic<absl::optional<bool>> random_save_present; + +bool CooperativeSavePresent() { + auto present = cooperative_save_present.load(); + if (!present.has_value()) { + present = getenv("GVISOR_COOPERATIVE_SAVE_TEST") != nullptr; + cooperative_save_present.store(present); + } + return present.value(); +} + +bool RandomSavePresent() { + auto present = random_save_present.load(); + if (!present.has_value()) { + present = getenv("GVISOR_RANDOM_SAVE_TEST") != nullptr; + random_save_present.store(present); } - return mode == CooperativeSaveMode::kAvailable; + return present.value(); } std::atomic<int> save_disable; } // namespace +bool IsRunningWithSaveRestore() { + return CooperativeSavePresent() || RandomSavePresent(); +} + +void MaybeSave() { + if (CooperativeSavePresent() && save_disable.load() == 0) { + internal::DoCooperativeSave(); + } +} + DisableSave::DisableSave() { save_disable++; } DisableSave::~DisableSave() { reset(); } @@ -61,11 +73,5 @@ void DisableSave::reset() { } } -namespace internal { -bool ShouldSave() { - return CooperativeSaveEnabled() && (save_disable.load() == 0); -} -} // namespace internal - } // namespace testing } // namespace gvisor diff --git a/test/util/save_util.h b/test/util/save_util.h index bddad6120..e7218ae88 100644 --- a/test/util/save_util.h +++ b/test/util/save_util.h @@ -17,9 +17,17 @@ namespace gvisor { namespace testing { -// Disable save prevents saving while the given function executes. + +// Returns true if the environment in which the calling process is executing +// allows the test to be checkpointed and restored during execution. +bool IsRunningWithSaveRestore(); + +// May perform a co-operative save cycle. // -// This lasts the duration of the object, unless reset is called. +// errno is guaranteed to be preserved. +void MaybeSave(); + +// Causes MaybeSave to become a no-op until destroyed or reset. class DisableSave { public: DisableSave(); @@ -37,13 +45,13 @@ class DisableSave { bool reset_ = false; }; -// May perform a co-operative save cycle. +namespace internal { + +// Causes a co-operative save cycle to occur. // // errno is guaranteed to be preserved. -void MaybeSave(); +void DoCooperativeSave(); -namespace internal { -bool ShouldSave(); } // namespace internal } // namespace testing diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index fbac94912..57431b3ea 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -30,19 +30,19 @@ namespace gvisor { namespace testing { - -void MaybeSave() { - if (internal::ShouldSave()) { - int orig_errno = errno; - // We use it to trigger saving the sentry state - // when this syscall is called. - // Notice: this needs to be a valid syscall - // that is not used in any of the syscall tests. - syscall(SYS_TRIGGER_SAVE, nullptr, 0); - errno = orig_errno; - } +namespace internal { + +void DoCooperativeSave() { + int orig_errno = errno; + // We use it to trigger saving the sentry state + // when this syscall is called. + // Notice: this needs to be a valid syscall + // that is not used in any of the syscall tests. 
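Stepping out of the hunk for a moment (the SYS_TRIGGER_SAVE call that this comment describes follows immediately below), an illustrative consumer of the restructured save_util API; the test name and body are hypothetical and assume the usual gtest and test_util headers:

#include "test/util/save_util.h"

TEST(SaveUtilExample, SaveAwareAssertions) {
  // Gate checks that do not hold across checkpoint/restore, as stat.cc now
  // does for gofer inode numbers.
  if (!IsRunningWithSaveRestore()) {
    // ... inode-number equality assertions ...
  }

  {
    DisableSave ds;  // Suppress cooperative saves, e.g. around a hot loop.
    // ... many quick iterations ...
    ds.reset();      // Re-enable saves once the expensive section is done.
  }

  MaybeSave();  // Explicit cooperative save point; errno is preserved.
}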
+ syscall(SYS_TRIGGER_SAVE, nullptr, 0); + errno = orig_errno; } +} // namespace internal } // namespace testing } // namespace gvisor diff --git a/test/util/save_util_other.cc b/test/util/save_util_other.cc index 931af2c29..7749ded76 100644 --- a/test/util/save_util_other.cc +++ b/test/util/save_util_other.cc @@ -14,13 +14,17 @@ #ifndef __linux__ +#include "test/util/logging.h" + namespace gvisor { namespace testing { +namespace internal { -void MaybeSave() { - // Saving is never available in a non-linux environment. +void DoCooperativeSave() { + TEST_CHECK_MSG(false, "DoCooperativeSave not implemented"); } +} // namespace internal } // namespace testing } // namespace gvisor diff --git a/test/util/signal_util.h b/test/util/signal_util.h index e7b66aa51..20eebd7e4 100644 --- a/test/util/signal_util.h +++ b/test/util/signal_util.h @@ -88,7 +88,7 @@ inline void FixupFault(ucontext_t* ctx) { #elif __aarch64__ inline void Fault() { // Zero and dereference x0. - asm("mov xzr, x0\r\n" + asm("mov x0, xzr\r\n" "str xzr, [x0]\r\n" : : diff --git a/test/util/timer_util.h b/test/util/timer_util.h index 926e6632f..e389108ef 100644 --- a/test/util/timer_util.h +++ b/test/util/timer_util.h @@ -33,6 +33,11 @@ namespace gvisor { namespace testing { +// From Linux's include/uapi/asm-generic/siginfo.h. +#ifndef sigev_notify_thread_id +#define sigev_notify_thread_id _sigev_un._tid +#endif + // Returns the current time. absl::Time Now(clockid_t id); diff --git a/tools/bazel.mk b/tools/bazel.mk index 88431ce66..3a7de427f 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -26,13 +26,13 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ BUILD_ROOTS := bazel-bin/ bazel-out/ # Bazel container configuration (see below). -USER ?= gvisor -HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +USER := $(shell whoami) +HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) BUILDER_BASE := gvisor.dev/images/default BUILDER_IMAGE := gvisor.dev/images/builder -BUILDER_NAME ?= gvisor-builder-$(HASH) -DOCKER_NAME ?= gvisor-bazel-$(HASH) -DOCKER_PRIVILEGED ?= --privileged +BUILDER_NAME := gvisor-builder-$(HASH) +DOCKER_NAME := gvisor-bazel-$(HASH) +DOCKER_PRIVILEGED := --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock @@ -59,6 +59,25 @@ ifeq (true,$(shell [[ -t 0 ]] && echo true)) FULL_DOCKER_EXEC_OPTIONS += --tty endif +# Add basic UID/GID options. +# +# Note that USERADD_DOCKER and GROUPADD_DOCKER are both defined as "deferred" +# variables in Make terminology, that is they will be expanded at time of use +# and may include other variables, including those defined below. +# +# NOTE: we pass -l to useradd below because otherwise you can hit a bug +# best described here: +# https://github.com/moby/moby/issues/5419#issuecomment-193876183 +# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out +# out of disk space. +ifneq ($(UID),0) +USERADD_DOCKER += useradd -l --uid $(UID) --non-unique --no-create-home \ + --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && +endif +ifneq ($(GID),0) +GROUPADD_DOCKER += groupadd --gid $(GID) --non-unique $(USER) && +endif + # Add docker passthrough options. 
ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" @@ -91,19 +110,12 @@ ifneq (,$(BAZEL_CONFIG)) OPTIONS += --config=$(BAZEL_CONFIG) endif -# NOTE: we pass -l to useradd below because otherwise you can hit a bug -# best described here: -# https://github.com/moby/moby/issues/5419#issuecomment-193876183 -# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out -# out of disk space. bazel-image: load-default @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ $(BUILDER_BASE) \ - sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ - $(GROUPADD_DOCKER) \ - useradd -l --uid $(UID) --non-unique --no-create-home \ - --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + sh -c "$(GROUPADD_DOCKER) \ + $(USERADD_DOCKER) \ if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) @docker rm -f $(BUILDER_NAME) diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl index d388346a5..661c9727e 100644 --- a/tools/bazeldefs/go.bzl +++ b/tools/bazeldefs/go.bzl @@ -94,10 +94,10 @@ def go_rule(rule, implementation, **kwargs): toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) -def go_test_library(target): - if hasattr(target.attr, "embed") and len(target.attr.embed) > 0: - return target.attr.embed[0] - return None +def go_embed_libraries(target): + if hasattr(target.attr, "embed"): + return target.attr.embed + return [] def go_context(ctx, goos = None, goarch = None, std = False): """Extracts a standard Go context struct. diff --git a/tools/bigquery/BUILD b/tools/bigquery/BUILD index 2b0062a63..1cea9e1c9 100644 --- a/tools/bigquery/BUILD +++ b/tools/bigquery/BUILD @@ -9,5 +9,8 @@ go_library( visibility = [ "//:sandbox", ], - deps = ["@com_google_cloud_go_bigquery//:go_default_library"], + deps = [ + "@com_google_cloud_go_bigquery//:go_default_library", + "@org_golang_google_api//option:go_default_library", + ], ) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 5f1a882de..544af3876 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -25,22 +25,30 @@ import ( "time" bq "cloud.google.com/go/bigquery" + "google.golang.org/api/option" ) -// Benchmark is the top level structure of recorded benchmark data. BigQuery +// Suite is the top level structure for a benchmark run. BigQuery // will infer the schema from this. +type Suite struct { + Name string `bq:"name"` + Conditions []*Condition `bq:"conditions"` + Benchmarks []*Benchmark `bq:"benchmarks"` + Official bool `bq:"official"` + Timestamp time.Time `bq:"timestamp"` +} + +// Benchmark represents an individual benchmark in a suite. type Benchmark struct { Name string `bq:"name"` Condition []*Condition `bq:"condition"` - Timestamp time.Time `bq:"timestamp"` - Official bool `bq:"official"` Metric []*Metric `bq:"metric"` - Metadata *Metadata `bq:"metadata"` } -// Condition represents qualifiers for the benchmark. For example: +// Condition represents qualifiers for the benchmark or suite. For example: // Get_Pid/1/real_time would have Benchmark Name "Get_Pid" with "1" -// and "real_time" parameters as conditions. +// and "real_time" parameters as conditions. Suite conditions include +// information such as the CL number and platform name. 
type Condition struct { Name string `bq:"name"` Value string `bq:"value"` @@ -53,19 +61,9 @@ type Metric struct { Sample float64 `bq:"sample"` } -// Metadata about this benchmark. -type Metadata struct { - CL string `bq:"changelist"` - IterationID string `bq:"iteration_id"` - PendingCL string `bq:"pending_cl"` - Workflow string `bq:"workflow"` - Platform string `bq:"platform"` - Gofer string `bq:"gofer"` -} - // InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. -func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) error { - client, err := bq.NewClient(ctx, projectID) +func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string, opts []option.ClientOption) error { + client, err := bq.NewClient(ctx, projectID, opts...) if err != nil { return fmt.Errorf("failed to initialize client on project %s: %v", projectID, err) } @@ -77,7 +75,7 @@ func InitBigQuery(ctx context.Context, projectID, datasetID, tableID string) err } table := dataset.Table(tableID) - schema, err := bq.InferSchema(Benchmark{}) + schema, err := bq.InferSchema(Suite{}) if err != nil { return fmt.Errorf("failed to infer schema: %v", err) } @@ -107,26 +105,35 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { } // NewBenchmark initializes a new benchmark. -func NewBenchmark(name string, iters int, official bool) *Benchmark { +func NewBenchmark(name string, iters int) *Benchmark { return &Benchmark{ - Name: name, - Timestamp: time.Now().UTC(), - Official: official, - Metric: make([]*Metric, 0), + Name: name, + Metric: make([]*Metric, 0), + } +} + +// NewSuite initializes a new Suite. +func NewSuite(name string, official bool) *Suite { + return &Suite{ + Name: name, + Timestamp: time.Now().UTC(), + Benchmarks: make([]*Benchmark, 0), + Conditions: make([]*Condition, 0), + Official: official, } } // SendBenchmarks sends the slice of benchmarks to the BigQuery dataset/table. -func SendBenchmarks(ctx context.Context, benchmarks []*Benchmark, projectID, datasetID, tableID string) error { - client, err := bq.NewClient(ctx, projectID) +func SendBenchmarks(ctx context.Context, suite *Suite, projectID, datasetID, tableID string, opts []option.ClientOption) error { + client, err := bq.NewClient(ctx, projectID, opts...) 
if err != nil { return fmt.Errorf("failed to initialize client on project: %s: %v", projectID, err) } defer client.Close() uploader := client.Dataset(datasetID).Table(tableID).Uploader() - if err = uploader.Put(ctx, benchmarks); err != nil { - return fmt.Errorf("failed to upload benchmarks to proejct %s, table %s.%s: %v", projectID, datasetID, tableID, err) + if err = uploader.Put(ctx, suite); err != nil { + return fmt.Errorf("failed to upload benchmarks %s to project %s, table %s.%s: %v", suite.Name, projectID, datasetID, tableID, err) } return nil diff --git a/tools/defs.bzl b/tools/defs.bzl index bb291c512..d75e40863 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -86,8 +86,10 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, ** ) nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = kwargs.get("srcs", []), - library = ":" + name + "_nogo_library", + deps = [":" + name + "_nogo_library"], + tags = ["nogo"], ) def calculate_sets(srcs): @@ -218,8 +220,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F if nogo: nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = all_srcs, - library = ":" + name, + deps = [":" + name], + tags = ["nogo"], ) if marshal: @@ -255,8 +259,10 @@ def go_test(name, nogo = True, **kwargs): if nogo: nogo_test( name = name + "_nogo", + config = "//:nogo_config", srcs = kwargs.get("srcs", []), - library = ":" + name, + deps = [":" + name], + tags = ["nogo"], ) def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): diff --git a/tools/github/nogo/BUILD b/tools/github/nogo/BUILD index 0633eaf19..19b7eec4d 100644 --- a/tools/github/nogo/BUILD +++ b/tools/github/nogo/BUILD @@ -10,7 +10,7 @@ go_library( "//tools/github:__subpackages__", ], deps = [ - "//tools/nogo/util", + "//tools/nogo", "@com_github_google_go_github_v28//github:go_default_library", ], ) diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go index b2bc63459..27ab1b8eb 100644 --- a/tools/github/nogo/nogo.go +++ b/tools/github/nogo/nogo.go @@ -24,7 +24,7 @@ import ( "time" "github.com/google/go-github/github" - "gvisor.dev/gvisor/tools/nogo/util" + "gvisor.dev/gvisor/tools/nogo" ) // FindingsPoster is a simple wrapper around the GitHub api. @@ -35,7 +35,7 @@ type FindingsPoster struct { dryRun bool startTime time.Time - findings map[util.Finding]struct{} + findings map[nogo.Finding]struct{} client *github.Client } @@ -47,7 +47,7 @@ func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun commit: commit, dryRun: dryRun, startTime: time.Now(), - findings: make(map[util.Finding]struct{}), + findings: make(map[nogo.Finding]struct{}), client: client, } } @@ -63,7 +63,7 @@ func (p *FindingsPoster) Walk(paths []string) error { if !strings.HasSuffix(filename, ".findings") || info.IsDir() { return nil } - findings, err := util.ExtractFindingsFromFile(filename) + findings, err := nogo.ExtractFindingsFromFile(filename) if err != nil { return err } @@ -86,7 +86,7 @@ func (p *FindingsPoster) Post() error { if p.dryRun { for finding, _ := range p.findings { // Pretty print, so that this is useful for debugging. - fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Path, finding.Line, finding.Message) + fmt.Printf("%s: (%s+%d) %s\n", finding.Category, finding.Position.Filename, finding.Position.Line, finding.Message) } return nil } @@ -115,12 +115,13 @@ func (p *FindingsPoster) Post() error { } annotationLevel := "failure" // Always. 
for finding, _ := range p.findings { + title := string(finding.Category) opts.Output.Annotations = append(opts.Output.Annotations, &github.CheckRunAnnotation{ - Path: &finding.Path, - StartLine: &finding.Line, - EndLine: &finding.Line, + Path: &finding.Position.Filename, + StartLine: &finding.Position.Line, + EndLine: &finding.Position.Line, Message: &finding.Message, - Title: &finding.Category, + Title: &title, AnnotationLevel: &annotationLevel, }) } diff --git a/tools/github/reviver/github.go b/tools/github/reviver/github.go index a95df0fb6..c4b624f2a 100644 --- a/tools/github/reviver/github.go +++ b/tools/github/reviver/github.go @@ -121,13 +121,24 @@ func (b *GitHubBugger) Activate(todo *Todo) (bool, error) { return true, nil } +var issuePrefixes = []string{ + "gvisor.dev/issue/", + "gvisor.dev/issues/", +} + // parseIssueNo parses the issue number out of the issue url. +// +// 0 is returned if url does not correspond to an issue. func parseIssueNo(url string) (int, error) { - const prefix = "gvisor.dev/issue/" - // First check if I can handle the TODO. - idStr := strings.TrimPrefix(url, prefix) - if len(url) == len(idStr) { + var idStr string + for _, p := range issuePrefixes { + if str := strings.TrimPrefix(url, p); len(str) < len(url) { + idStr = str + break + } + } + if len(idStr) == 0 { return 0, nil } diff --git a/tools/github/reviver/reviver_test.go b/tools/github/reviver/reviver_test.go index a9fb1f9f1..851306c9d 100644 --- a/tools/github/reviver/reviver_test.go +++ b/tools/github/reviver/reviver_test.go @@ -33,6 +33,15 @@ func TestProcessLine(t *testing.T) { }, }, { + line: "// TODO(foobar.com/issues/123): comment, bla. blabla.", + want: &Todo{ + Issue: "foobar.com/issues/123", + Locations: []Location{ + {Comment: "comment, bla. blabla."}, + }, + }, + }, + { line: "// FIXME(b/123): internal bug", want: &Todo{ Issue: "b/123", diff --git a/tools/go_branch.sh b/tools/go_branch.sh index e5c060024..71d036b12 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -14,23 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -xeo pipefail +set -xeou pipefail # Discovery the package name from the go.mod file. -declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) -declare -r origpwd=$(pwd) -declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") +declare module origpwd othersrc +module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2) +origpwd=$(pwd) +othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE") +readonly module origpwd othersrc + # Check that gopath has been built. -declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" -if ! [ -d "${gopath_dir}" ]; then +declare gopath_dir +gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}" +readonly gopath_dir +if ! [[ -d "${gopath_dir}" ]]; then echo "No gopath directory found; build the :gopath target." >&2 exit 1 fi # Create a temporary working directory, and ensure that this directory and all # subdirectories are cleaned up upon exit. -declare -r tmp_dir=$(mktemp -d) +declare tmp_dir +tmp_dir=$(mktemp -d) +readonly tmp_dir finish() { cd # Leave tmp_dir. rm -rf "${tmp_dir}" @@ -38,21 +45,27 @@ finish() { trap finish EXIT # Record the current working commit. -declare -r head=$(git describe --always) +declare head +head=$(git describe --always) +readonly head # We expect to have an existing go branch that we will use as the basis for this # commit. That branch may be empty, but it must exist. 
We search for this branch # using the local branch, the "origin" branch, and other remotes, in order. git fetch --all -declare -r go_branch=$( \ +declare go_branch +go_branch=$( \ git show-ref --hash refs/heads/go || \ git show-ref --hash refs/remotes/origin/go || \ git show-ref --hash go | head -n 1 \ ) +readonly go_branch # Clone the current repository to the temporary directory, and check out the # current go_branch directory. We move to the new repository for convenience. -declare -r repo_orig="$(pwd)" +declare repo_orig +repo_orig="$(pwd)" +readonly repo_orig declare -r repo_new="${tmp_dir}/repository" git clone . "${repo_new}" cd "${repo_new}" @@ -68,8 +81,8 @@ git checkout -b go "${go_branch}" # # N.B. The git behavior changed at some point and the relevant flag was added # to allow for override, so try the only behavior first then pass the flag. -git merge --no-commit --strategy ours ${head} || \ - git merge --allow-unrelated-histories --no-commit --strategy ours ${head} +git merge --no-commit --strategy ours "${head}" || \ + git merge --allow-unrelated-histories --no-commit --strategy ours "${head}" # Normalize the permissions on the old branch. Note that they should be # normalized if constructed by this tool, but we do so before the rsync. @@ -96,7 +109,7 @@ EOF # There are a few solitary files that can get left behind due to the way bazel # constructs the gopath target. Note that we don't find all Go files here # because they may correspond to unused templates, etc. -declare -ar binaries=( "runsc" "shim/v1" "shim/v2" ) +declare -ar binaries=( "runsc" "shim/v1" "shim/v2" "webhook" ) for target in "${binaries[@]}"; do mkdir -p "${target}" cp "${repo_orig}/${target}"/*.go "${target}/" @@ -109,7 +122,11 @@ find . -type f -exec chmod 0644 {} \; find . -type d -exec chmod 0755 {} \; # Update the current working set and commit. -git add . && git commit -m "Merge ${head} (automated)" +# If the current working commit has already been committed to the remote go +# branch, then we have nothing to commit here. So allow empty commit. This can +# occur when this script is run parallely (via pull_request and push events) +# and the push workflow finishes before the pull_request workflow can run this. +git add . && git commit --allow-empty -m "Merge ${head} (automated)" # Push the branch back to the original repository. git remote add orig "${repo_orig}" && git push -f orig go:go diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go index 90d3aa1e0..370650e46 100644 --- a/tools/go_generics/imports.go +++ b/tools/go_generics/imports.go @@ -48,7 +48,7 @@ func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[st // Create a new entry in the used map. path := imports[importName] if path == "" { - return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig) + return fmt.Errorf("unknown path to package '%s', used in '%s'", importName, orig) } m = &importedPackage{ @@ -72,7 +72,7 @@ func convertExpression(s string, imports mapValue, used map[string]*importedPack // Parse the expression in the input string. expr, err := parser.ParseExpr(s) if err != nil { - return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err) + return "", fmt.Errorf("unable to parse \"%s\": %v", s, err) } // Go through the AST and update references. 
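For illustration, a minimal standalone sketch of the URL-prefix handling that the parseIssueNo change in tools/github/reviver/github.go above introduces. The helper name, the sample URLs, and the final strconv.Atoi step (the hunk above is truncated before the numeric parse) are assumptions, not part of the commit:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    // issuePrefixes mirrors the list added in the commit: both the singular
    // and plural gvisor.dev issue URL forms are accepted.
    var issuePrefixes = []string{
        "gvisor.dev/issue/",
        "gvisor.dev/issues/",
    }

    // parseIssueNoSketch returns the issue number for a recognized URL, or
    // 0 (with a nil error) when the URL matches none of the known prefixes.
    func parseIssueNoSketch(url string) (int, error) {
        var idStr string
        for _, p := range issuePrefixes {
            if str := strings.TrimPrefix(url, p); len(str) < len(url) {
                idStr = str
                break
            }
        }
        if len(idStr) == 0 {
            return 0, nil // Not an issue URL this tool can handle.
        }
        return strconv.Atoi(idStr) // Assumed final step; truncated in the hunk above.
    }

    func main() {
        for _, u := range []string{
            "gvisor.dev/issue/123",  // -> 123
            "gvisor.dev/issues/123", // -> 123 (newly accepted form)
            "example.com/issue/123", // -> 0 (unknown prefix)
        } {
            n, err := parseIssueNoSketch(u)
            fmt.Println(u, n, err)
        }
    }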
diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go index 7f62b0a2b..df14ae98e 100644 --- a/tools/go_marshal/test/escape/escape.go +++ b/tools/go_marshal/test/escape/escape.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package escape contains test cases for escape analysis. package escape import ( diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index d9e9f341b..e7e3ed74a 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -161,7 +161,7 @@ type TestArray [sizeA]int32 // +marshal type TestArray2 [sizeA * sizeB]int32 -// TestArray2 is a newtype on an array with a simple arithmetic expression of +// TestArray3 is a newtype on an array with a simple arithmetic expression of // mixed constants and literals for the array length. // // +marshal diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 3c6be3339..12b8b597c 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -20,20 +20,14 @@ nogo_stdlib( visibility = ["//visibility:public"], ) -sh_binary( - name = "gentest", - srcs = ["gentest.sh"], - visibility = ["//visibility:public"], -) - go_library( name = "nogo", srcs = [ + "analyzers.go", "build.go", "config.go", - "matchers.go", + "findings.go", "nogo.go", - "register.go", ], nogo = False, visibility = ["//:sandbox"], diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go new file mode 100644 index 000000000..b919bc2f8 --- /dev/null +++ b/tools/nogo/analyzers.go @@ -0,0 +1,131 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package nogo + +import ( + "encoding/gob" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/asmdecl" + "golang.org/x/tools/go/analysis/passes/assign" + "golang.org/x/tools/go/analysis/passes/atomic" + "golang.org/x/tools/go/analysis/passes/bools" + "golang.org/x/tools/go/analysis/passes/buildtag" + "golang.org/x/tools/go/analysis/passes/cgocall" + "golang.org/x/tools/go/analysis/passes/composite" + "golang.org/x/tools/go/analysis/passes/copylock" + "golang.org/x/tools/go/analysis/passes/errorsas" + "golang.org/x/tools/go/analysis/passes/httpresponse" + "golang.org/x/tools/go/analysis/passes/loopclosure" + "golang.org/x/tools/go/analysis/passes/lostcancel" + "golang.org/x/tools/go/analysis/passes/nilfunc" + "golang.org/x/tools/go/analysis/passes/nilness" + "golang.org/x/tools/go/analysis/passes/printf" + "golang.org/x/tools/go/analysis/passes/shadow" + "golang.org/x/tools/go/analysis/passes/shift" + "golang.org/x/tools/go/analysis/passes/stdmethods" + "golang.org/x/tools/go/analysis/passes/stringintconv" + "golang.org/x/tools/go/analysis/passes/structtag" + "golang.org/x/tools/go/analysis/passes/tests" + "golang.org/x/tools/go/analysis/passes/unmarshal" + "golang.org/x/tools/go/analysis/passes/unreachable" + "golang.org/x/tools/go/analysis/passes/unsafeptr" + "golang.org/x/tools/go/analysis/passes/unusedresult" + "honnef.co/go/tools/staticcheck" + "honnef.co/go/tools/stylecheck" + + "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checkunsafe" +) + +// AllAnalyzers is a list of all available analyzers. +var AllAnalyzers = []*analysis.Analyzer{ + asmdecl.Analyzer, + assign.Analyzer, + atomic.Analyzer, + bools.Analyzer, + buildtag.Analyzer, + cgocall.Analyzer, + composite.Analyzer, + copylock.Analyzer, + errorsas.Analyzer, + httpresponse.Analyzer, + loopclosure.Analyzer, + lostcancel.Analyzer, + nilfunc.Analyzer, + nilness.Analyzer, + printf.Analyzer, + shift.Analyzer, + stdmethods.Analyzer, + stringintconv.Analyzer, + shadow.Analyzer, + structtag.Analyzer, + tests.Analyzer, + unmarshal.Analyzer, + unreachable.Analyzer, + unsafeptr.Analyzer, + unusedresult.Analyzer, + checkescape.Analyzer, + checkunsafe.Analyzer, +} + +// EscapeAnalyzers is a list of escape-related analyzers. +var EscapeAnalyzers = []*analysis.Analyzer{ + checkescape.EscapeAnalyzer, +} + +func register(all []*analysis.Analyzer) { + // Register all fact types. + // + // N.B. This needs to be done recursively, because there may be + // analyzers in the Requires list that do not appear explicitly above. + registered := make(map[*analysis.Analyzer]struct{}) + var registerOne func(*analysis.Analyzer) + registerOne = func(a *analysis.Analyzer) { + if _, ok := registered[a]; ok { + return + } + + // Register dependencies. + for _, da := range a.Requires { + registerOne(da) + } + + // Register local facts. + for _, f := range a.FactTypes { + gob.Register(f) + } + + registered[a] = struct{}{} // Done. + } + for _, a := range all { + registerOne(a) + } +} + +func init() { + // Add all staticcheck analyzers. + for _, a := range staticcheck.Analyzers { + AllAnalyzers = append(AllAnalyzers, a) + } + // Add all stylecheck analyzers. + for _, a := range stylecheck.Analyzers { + AllAnalyzers = append(AllAnalyzers, a) + } + + // Register lists. 
+ register(AllAnalyzers) + register(EscapeAnalyzers) +} diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 55d34760f..d173cff1f 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -20,22 +20,6 @@ import ( "os" ) -var ( - // internalPrefix is the internal path prefix. Note that this is not - // special, as paths should be passed relative to the repository root - // and should not have any special prefix applied. - internalPrefix = fmt.Sprintf("^") - - // internalDefault is applied when no paths are provided. - internalDefault = fmt.Sprintf("%s/.*", notPath("external")) - - // generatedPrefix is a regex for generated files. - generatedPrefix = "^(.*/)?(bazel-genfiles|bazel-out|bazel-bin)/" - - // externalPrefix is external workspace packages. - externalPrefix = "^external/" -) - // findStdPkg needs to find the bundled standard library packages. func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) { if path == "C" { diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD index 21ba2c306..e18483a18 100644 --- a/tools/nogo/check/BUILD +++ b/tools/nogo/check/BUILD @@ -2,8 +2,6 @@ load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) -# Note that the check binary must be public, since an aspect may be applied -# across lots of different rules in different repositories. go_binary( name = "check", srcs = ["main.go"], diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go index 3828edf3a..69bdfe502 100644 --- a/tools/nogo/check/main.go +++ b/tools/nogo/check/main.go @@ -16,9 +16,99 @@ package main import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "gvisor.dev/gvisor/tools/nogo" ) +var ( + packageFile = flag.String("package", "", "package configuration file (in JSON format)") + stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") + findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") + factsOutput = flag.String("facts", "", "output file for facts (optional)") + escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") +) + +func loadConfig(file string, config interface{}) interface{} { + // Load the configuration. + f, err := os.Open(file) + if err != nil { + log.Fatalf("unable to open configuration %q: %v", file, err) + } + defer f.Close() + dec := json.NewDecoder(f) + dec.DisallowUnknownFields() + if err := dec.Decode(config); err != nil { + log.Fatalf("unable to decode configuration: %v", err) + } + return config +} + func main() { - nogo.Main() + // Parse all flags. + flag.Parse() + + var ( + findings []nogo.Finding + factData []byte + err error + ) + + // Check & load the configuration. + if *packageFile != "" && *stdlibFile != "" { + log.Fatalf("unable to perform stdlib and package analysis; provide only one!") + } + + // Run the configuration. + if *stdlibFile != "" { + // Perform basic analysis. + c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig) + findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers) + + } else if *packageFile != "" { + // Perform basic analysis. + c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig) + findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil) + + // Do we need to do escape analysis? 
+ if *escapesOutput != "" { + escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil) + if err != nil { + log.Fatalf("error performing escape analysis: %v", err) + } + if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil { + log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err) + } + } + } else { + log.Fatalf("please provide at least one of package or stdlib!") + } + + // Check that analysis was successful. + if err != nil { + log.Fatalf("error performing analysis: %v", err) + } + + // Save facts. + if *factsOutput != "" { + if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { + log.Fatalf("error saving findings to %q: %v", *factsOutput, err) + } + } + + // Write all findings. + if *findingsOutput != "" { + if err := nogo.WriteFindingsToFile(findings, *findingsOutput); err != nil { + log.Fatalf("error writing findings to %q: %v", *findingsOutput, err) + } + } else { + for _, finding := range findings { + fmt.Fprintf(os.Stdout, "%s\n", finding.String()) + } + } } diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 0853f03cf..2fea5b3e1 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -15,544 +15,247 @@ package nogo import ( - "golang.org/x/tools/go/analysis" - "golang.org/x/tools/go/analysis/passes/asmdecl" - "golang.org/x/tools/go/analysis/passes/assign" - "golang.org/x/tools/go/analysis/passes/atomic" - "golang.org/x/tools/go/analysis/passes/bools" - "golang.org/x/tools/go/analysis/passes/buildtag" - "golang.org/x/tools/go/analysis/passes/cgocall" - "golang.org/x/tools/go/analysis/passes/composite" - "golang.org/x/tools/go/analysis/passes/copylock" - "golang.org/x/tools/go/analysis/passes/errorsas" - "golang.org/x/tools/go/analysis/passes/httpresponse" - "golang.org/x/tools/go/analysis/passes/loopclosure" - "golang.org/x/tools/go/analysis/passes/lostcancel" - "golang.org/x/tools/go/analysis/passes/nilfunc" - "golang.org/x/tools/go/analysis/passes/nilness" - "golang.org/x/tools/go/analysis/passes/printf" - "golang.org/x/tools/go/analysis/passes/shadow" - "golang.org/x/tools/go/analysis/passes/shift" - "golang.org/x/tools/go/analysis/passes/stdmethods" - "golang.org/x/tools/go/analysis/passes/stringintconv" - "golang.org/x/tools/go/analysis/passes/structtag" - "golang.org/x/tools/go/analysis/passes/tests" - "golang.org/x/tools/go/analysis/passes/unmarshal" - "golang.org/x/tools/go/analysis/passes/unreachable" - "golang.org/x/tools/go/analysis/passes/unsafeptr" - "golang.org/x/tools/go/analysis/passes/unusedresult" - "honnef.co/go/tools/staticcheck" - "honnef.co/go/tools/stylecheck" - - "gvisor.dev/gvisor/tools/checkescape" - "gvisor.dev/gvisor/tools/checkunsafe" + "fmt" + "regexp" ) -var analyzerConfig = map[*analysis.Analyzer]matcher{ - // Standard analyzers. - asmdecl.Analyzer: alwaysMatches(), - assign.Analyzer: externalExcluded( - ".*gazelle/walk/walk.go", // False positive. - ), - atomic.Analyzer: alwaysMatches(), - bools.Analyzer: alwaysMatches(), - buildtag.Analyzer: alwaysMatches(), - cgocall.Analyzer: alwaysMatches(), - composite.Analyzer: and( - disableMatches(), // Disabled for now. - resultExcluded{ - "Object_", - "Range{", - }, - ), - copylock.Analyzer: internalMatches(), // Common external issues (e.g. protos). - errorsas.Analyzer: alwaysMatches(), - httpresponse.Analyzer: alwaysMatches(), - loopclosure.Analyzer: alwaysMatches(), - lostcancel.Analyzer: internalMatches(), // Common external issues. 
- nilfunc.Analyzer: alwaysMatches(), - nilness.Analyzer: and( - internalMatches(), // Common "tautological checks". - internalExcluded( - "pkg/sentry/platform/kvm/kvm_test.go", // Intentional. - "tools/bigquery/bigquery.go", // False positive. - ), - ), - printf.Analyzer: alwaysMatches(), - shift.Analyzer: alwaysMatches(), - stdmethods.Analyzer: internalMatches(), // Common external issues (e.g. methods named "Write"). - stringintconv.Analyzer: and( - internalExcluded(), - externalExcluded( - ".*protobuf/.*.go", // Bad conversions. - ".*flate/huffman_bit_writer.go", // Bad conversion. +// GroupName is a named group. +type GroupName string + +// AnalyzerName is a named analyzer. +type AnalyzerName string + +// Group represents a named collection of files. +type Group struct { + // Name is the short name for the group. + Name GroupName `yaml:"name"` + + // Regex matches all full paths in the group. + Regex string `yaml:"regex"` + regex *regexp.Regexp `yaml:"-"` + + // Default determines the default group behavior. + // + // If Default is true, all Analyzers are enabled for this + // group. Otherwise, Analyzers must be individually enabled + // by specifying a (possible empty) ItemConfig for the group + // in the AnalyzerConfig. + Default bool `yaml:"default"` +} + +func (g *Group) compile() error { + r, err := regexp.Compile(g.Regex) + if err != nil { + return err + } + g.regex = r + return nil +} + +// ItemConfig is an (Analyzer,Group) configuration. +type ItemConfig struct { + // Exclude are analyzer exclusions. + // + // Exclude is a list of regular expressions. If the corresponding + // Analyzer emits a Finding for which Finding.Position.String() + // matches a regular expression in Exclude, the finding will not + // be reported. + Exclude []string `yaml:"exclude,omitempty"` + exclude []*regexp.Regexp `yaml:"-"` + + // Suppress are analyzer suppressions. + // + // Suppress is a list of regular expressions. If the corresponding + // Analyzer emits a Finding for which Finding.Message matches a regular + // expression in Suppress, the finding will not be reported. + Suppress []string `yaml:"suppress,omitempty"` + suppress []*regexp.Regexp `yaml:"-"` +} + +func compileRegexps(ss []string, rs *[]*regexp.Regexp) error { + *rs = make([]*regexp.Regexp, 0, len(ss)) + for _, s := range ss { + r, err := regexp.Compile(s) + if err != nil { + return err + } + *rs = append(*rs, r) + } + return nil +} + +func (i *ItemConfig) compile() error { + if i == nil { + // This may be nil if nothing is included in the + // item configuration. That's fine, there's nothing + // to compile and nothing to exclude & suppress. + return nil + } + if err := compileRegexps(i.Exclude, &i.exclude); err != nil { + return fmt.Errorf("in exclude: %w", err) + } + if err := compileRegexps(i.Suppress, &i.suppress); err != nil { + return fmt.Errorf("in suppress: %w", err) + } + return nil +} + +func (i *ItemConfig) merge(other *ItemConfig) { + i.Exclude = append(i.Exclude, other.Exclude...) + i.Suppress = append(i.Suppress, other.Suppress...) +} + +func (i *ItemConfig) shouldReport(fullPos, msg string) bool { + if i == nil { + // See above. + return true + } + for _, r := range i.exclude { + if r.MatchString(fullPos) { + return false + } + } + for _, r := range i.suppress { + if r.MatchString(msg) { + return false + } + } + return true +} + +// AnalyzerConfig is the configuration for a single analyzers. 
+// +// This map is keyed by individual Group names, to allow for different +// configurations depending on what Group the file belongs to. +type AnalyzerConfig map[GroupName]*ItemConfig + +func (a AnalyzerConfig) compile() error { + for name, gc := range a { + if err := gc.compile(); err != nil { + return fmt.Errorf("invalid group %q: %v", name, err) + } + } + return nil +} + +func (a AnalyzerConfig) merge(other AnalyzerConfig) { + // Merge all the groups. + for name, gc := range other { + old, ok := a[name] + if !ok || old == nil { + a[name] = gc // Not configured in a. + continue + } + old.merge(gc) + } +} + +func (a AnalyzerConfig) shouldReport(groupConfig *Group, fullPos, msg string) bool { + gc, ok := a[groupConfig.Name] + if !ok { + return groupConfig.Default + } + + // Note that if a section appears for a particular group + // for a particular analyzer, then it will now be enabled, + // and the group default no longer applies. + return gc.shouldReport(fullPos, msg) +} + +// Config is a nogo configuration. +type Config struct { + // Prefixes defines a set of regular expressions that + // are standard "prefixes", so that files can be grouped + // and specific rules applied to individual groups. + Groups []Group `yaml:"groups"` - // Runtime internal violations. - ".*reflect/value.go", - ".*encoding/xml/xml.go", - ".*runtime/pprof/internal/profile/proto.go", - ".*fmt/scan.go", - ".*go/types/conversions.go", - ".*golang.org/x/net/dns/dnsmessage/message.go", - ), - ), - shadow.Analyzer: disableMatches(), // Disabled for now. - structtag.Analyzer: internalMatches(), // External not subject to rules. - tests.Analyzer: alwaysMatches(), - unmarshal.Analyzer: alwaysMatches(), - unreachable.Analyzer: internalMatches(), - unsafeptr.Analyzer: and( - internalMatches(), - internalExcluded( - ".*_test.go", // Exclude tests. - "pkg/flipcall/.*_unsafe.go", // Special case. - "pkg/gohacks/gohacks_unsafe.go", // Special case. - "pkg/sentry/fs/fsutil/host_file_mapper_unsafe.go", // Special case. - "pkg/sentry/platform/kvm/bluepill_unsafe.go", // Special case. - "pkg/sentry/platform/kvm/machine_unsafe.go", // Special case. - "pkg/sentry/platform/ring0/pagetables/allocator_unsafe.go", // Special case. - "pkg/sentry/platform/safecopy/safecopy_unsafe.go", // Special case. - "pkg/sentry/vfs/mount_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/stub_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/switchto_google_unsafe.go", // Special case. - "pkg/sentry/platform/systrap/sysmsg_thread_unsafe.go", // Special case. - ), - ), - unusedresult.Analyzer: alwaysMatches(), + // Global is the global analyzer config. + Global AnalyzerConfig `yaml:"global"` - // Internal analyzers: external packages not subject. - checkescape.Analyzer: internalMatches(), - checkunsafe.Analyzer: internalMatches(), + // Analyzers are individual analyzer configurations. The + // key for each analyzer is the name of the analyzer. The + // value is either a boolean (enable/disable), or a map to + // the groups above. + Analyzers map[AnalyzerName]AnalyzerConfig `yaml:"analyzers"` } -func init() { - staticMatcher := and( - // Only match internal, non-generated files. - internalMatches(), - generatedExcluded(), +// Merge merges two configurations. +func (c *Config) Merge(other *Config) { + // Merge all groups. + for _, g := range other.Groups { + // Is there a matching group? If yes, we just delete + // it. This will preserve the order provided in the + // overriding file, even if it differs. 
+ for i := 0; i < len(c.Groups); i++ { + if g.Name == c.Groups[i].Name { + copy(c.Groups[i:], c.Groups[i+1:]) + c.Groups = c.Groups[:len(c.Groups)-1] + break + } + } + c.Groups = append(c.Groups, g) + } - // We use ALL_CAPS for system definitions, - // which are common enough in the code base - // that we shouldn't annotate exceptions. - // - // Same story for underscores. - resultExcluded([]string{ - "should not use ALL_CAPS in Go names", - "should not use underscores in Go names", - }), + // Merge global configurations. + c.Global.merge(other.Global) - // Exclude existing matches. - internalExcluded( - "pkg/abi/linux/fuse.go:22", - "pkg/abi/linux/fuse.go:25", - "pkg/abi/linux/socket.go:113", - "pkg/abi/linux/tty.go:73", - "pkg/bpf/decoder.go:112", - "pkg/cpuid/cpuid_x86.go:675", - "pkg/eventchannel/event.go:193", - "pkg/eventchannel/event.go:27", - "pkg/eventchannel/event_test.go:22", - "pkg/eventchannel/rate.go:19", - "pkg/gohacks/gohacks_unsafe.go:33", - "pkg/log/json.go:30", - "pkg/log/log.go:359", - "pkg/merkletree/merkletree.go:230", - "pkg/merkletree/merkletree.go:243", - "pkg/merkletree/merkletree.go:249", - "pkg/merkletree/merkletree.go:266", - "pkg/merkletree/merkletree.go:355", - "pkg/merkletree/merkletree.go:369", - "pkg/metric/metric_test.go:20", - "pkg/p9/p9test/client_test.go:687", - "pkg/p9/transport_test.go:196", - "pkg/pool/pool.go:15", - "pkg/refs/refcounter.go:510", - "pkg/refs/refcounter_test.go:169", - "pkg/refs_vfs2/refs.go:16", - "pkg/safemem/block_unsafe.go:89", - "pkg/seccomp/seccomp.go:82", - "pkg/segment/test/set_functions.go:15", - "pkg/sentry/arch/signal.go:166", - "pkg/sentry/arch/signal.go:171", - "pkg/sentry/control/pprof.go:196", - "pkg/sentry/devices/memdev/full.go:58", - "pkg/sentry/devices/memdev/null.go:59", - "pkg/sentry/devices/memdev/random.go:68", - "pkg/sentry/devices/memdev/zero.go:86", - "pkg/sentry/fdimport/fdimport.go:15", - "pkg/sentry/fs/attr.go:257", - "pkg/sentry/fsbridge/fs.go:116", - "pkg/sentry/fsbridge/vfs.go:124", - "pkg/sentry/fsbridge/vfs.go:70", - "pkg/sentry/fs/copy_up.go:365", - "pkg/sentry/fs/copy_up_test.go:65", - "pkg/sentry/fs/dev/net_tun.go:161", - "pkg/sentry/fs/dev/net_tun.go:63", - "pkg/sentry/fs/dev/null.go:97", - "pkg/sentry/fs/dirent_cache.go:64", - "pkg/sentry/fs/file_overlay.go:327", - "pkg/sentry/fs/file_overlay.go:524", - "pkg/sentry/fs/filetest/filetest.go:55", - "pkg/sentry/fs/filetest/filetest.go:60", - "pkg/sentry/fs/fs.go:77", - "pkg/sentry/fs/fsutil/file.go:290", - "pkg/sentry/fs/fsutil/file.go:346", - "pkg/sentry/fs/fsutil/host_file_mapper.go:105", - "pkg/sentry/fs/fsutil/inode_cached.go:676", - "pkg/sentry/fs/fsutil/inode_cached.go:772", - "pkg/sentry/fs/gofer/attr.go:120", - "pkg/sentry/fs/gofer/fifo.go:33", - "pkg/sentry/fs/gofer/inode.go:410", - "pkg/sentry/fsimpl/devpts/devpts.go:110", - "pkg/sentry/fsimpl/devpts/devpts.go:246", - "pkg/sentry/fsimpl/devpts/devpts.go:50", - "pkg/sentry/fsimpl/devpts/master.go:110", - "pkg/sentry/fsimpl/devpts/master.go:55", - "pkg/sentry/fsimpl/devpts/replica.go:113", - "pkg/sentry/fsimpl/devpts/replica.go:57", - "pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54", - "pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97", - "pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92", - "pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44", - "pkg/sentry/fsimpl/ext/disklayout/inode_new.go:91", - "pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93", - "pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66", - "pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53", - 
"pkg/sentry/fsimpl/eventfd/eventfd.go:268", - "pkg/sentry/fsimpl/ext/directory.go:163", - "pkg/sentry/fsimpl/ext/directory.go:164", - "pkg/sentry/fsimpl/ext/extent_file.go:142", - "pkg/sentry/fsimpl/ext/extent_file.go:143", - "pkg/sentry/fsimpl/ext/ext.go:105", - "pkg/sentry/fsimpl/ext/filesystem.go:287", - "pkg/sentry/fsimpl/ext/regular_file.go:153", - "pkg/sentry/fsimpl/ext/symlink.go:113", - "pkg/sentry/fsimpl/fuse/connection_control.go:194", - "pkg/sentry/fsimpl/fuse/dev.go:387", - "pkg/sentry/fsimpl/fuse/dev_test.go:318", - "pkg/sentry/fsimpl/fuse/fusefs.go:102", - "pkg/sentry/fsimpl/fuse/read_write.go:129", - "pkg/sentry/fsimpl/fuse/request_response.go:71", - "pkg/sentry/fsimpl/gofer/directory.go:135", - "pkg/sentry/fsimpl/gofer/filesystem.go:679", - "pkg/sentry/fsimpl/gofer/gofer.go:1694", - "pkg/sentry/fsimpl/gofer/gofer.go:276", - "pkg/sentry/fsimpl/gofer/regular_file.go:81", - "pkg/sentry/fsimpl/gofer/special_file.go:141", - "pkg/sentry/fsimpl/host/host.go:184", - "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50", - "pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90", - "pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273", - "pkg/sentry/fsimpl/kernfs/filesystem.go:247", - "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320", - "pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497", - "pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52", - "pkg/sentry/fsimpl/overlay/directory.go:119", - "pkg/sentry/fsimpl/overlay/filesystem.go:527", - "pkg/sentry/fsimpl/overlay/non_directory.go:152", - "pkg/sentry/fsimpl/overlay/overlay.go:115", - "pkg/sentry/fsimpl/overlay/overlay.go:719", - "pkg/sentry/fsimpl/pipefs/pipefs.go:74", - "pkg/sentry/fsimpl/proc/filesystem.go:52", - "pkg/sentry/fsimpl/proc/filesystem.go:81", - "pkg/sentry/fsimpl/proc/subtasks.go:126", - "pkg/sentry/fsimpl/proc/subtasks.go:189", - "pkg/sentry/fsimpl/proc/task_fds.go:168", - "pkg/sentry/fsimpl/proc/task_fds.go:228", - "pkg/sentry/fsimpl/proc/task_fds.go:301", - "pkg/sentry/fsimpl/proc/task_fds.go:318", - "pkg/sentry/fsimpl/proc/task_fds.go:67", - "pkg/sentry/fsimpl/proc/task_files.go:112", - "pkg/sentry/fsimpl/proc/task_files.go:158", - "pkg/sentry/fsimpl/proc/task_files.go:259", - "pkg/sentry/fsimpl/proc/task_files.go:285", - "pkg/sentry/fsimpl/proc/task_files.go:305", - "pkg/sentry/fsimpl/proc/task_files.go:384", - "pkg/sentry/fsimpl/proc/task_files.go:403", - "pkg/sentry/fsimpl/proc/task_files.go:428", - "pkg/sentry/fsimpl/proc/task_files.go:691", - "pkg/sentry/fsimpl/proc/task_files.go:770", - "pkg/sentry/fsimpl/proc/task_files.go:797", - "pkg/sentry/fsimpl/proc/task_files.go:828", - "pkg/sentry/fsimpl/proc/task_files.go:879", - "pkg/sentry/fsimpl/proc/task_files.go:910", - "pkg/sentry/fsimpl/proc/task_files.go:961", - "pkg/sentry/fsimpl/proc/task.go:127", - "pkg/sentry/fsimpl/proc/task.go:193", - "pkg/sentry/fsimpl/proc/task_net.go:134", - "pkg/sentry/fsimpl/proc/task_net.go:475", - "pkg/sentry/fsimpl/proc/task_net.go:491", - "pkg/sentry/fsimpl/proc/task_net.go:508", - "pkg/sentry/fsimpl/proc/task_net.go:665", - "pkg/sentry/fsimpl/proc/task_net.go:715", - "pkg/sentry/fsimpl/proc/task_net.go:779", - "pkg/sentry/fsimpl/proc/tasks_files.go:113", - "pkg/sentry/fsimpl/proc/tasks_files.go:388", - "pkg/sentry/fsimpl/proc/tasks.go:232", - "pkg/sentry/fsimpl/proc/tasks_sys.go:145", - "pkg/sentry/fsimpl/proc/tasks_sys.go:181", - "pkg/sentry/fsimpl/proc/tasks_sys.go:239", - "pkg/sentry/fsimpl/proc/tasks_sys.go:291", - "pkg/sentry/fsimpl/proc/tasks_sys.go:375", - "pkg/sentry/fsimpl/signalfd/signalfd.go:124", - 
"pkg/sentry/fsimpl/signalfd/signalfd.go:15", - "pkg/sentry/fsimpl/signalfd/signalfd.go:126", - "pkg/sentry/fsimpl/sockfs/sockfs.go:36", - "pkg/sentry/fsimpl/sockfs/sockfs.go:79", - "pkg/sentry/fsimpl/sys/kcov.go:49", - "pkg/sentry/fsimpl/sys/kcov.go:99", - "pkg/sentry/fsimpl/sys/sys.go:118", - "pkg/sentry/fsimpl/sys/sys.go:56", - "pkg/sentry/fsimpl/testutil/testutil.go:257", - "pkg/sentry/fsimpl/testutil/testutil.go:260", - "pkg/sentry/fsimpl/timerfd/timerfd.go:87", - "pkg/sentry/fsimpl/tmpfs/directory.go:112", - "pkg/sentry/fsimpl/tmpfs/filesystem.go:195", - "pkg/sentry/fsimpl/tmpfs/regular_file.go:226", - "pkg/sentry/fsimpl/tmpfs/regular_file.go:346", - "pkg/sentry/fsimpl/tmpfs/tmpfs.go:103", - "pkg/sentry/fsimpl/tmpfs/tmpfs.go:733", - "pkg/sentry/fsimpl/verity/filesystem.go:490", - "pkg/sentry/fsimpl/verity/verity.go:156", - "pkg/sentry/fsimpl/verity/verity.go:629", - "pkg/sentry/fsimpl/verity/verity.go:672", - "pkg/sentry/fs/mount.go:162", - "pkg/sentry/fs/mount.go:256", - "pkg/sentry/fs/mount_overlay.go:144", - "pkg/sentry/fs/mounts.go:432", - "pkg/sentry/fs/proc/exec_args.go:104", - "pkg/sentry/fs/proc/exec_args.go:73", - "pkg/sentry/fs/proc/fds.go:269", - "pkg/sentry/fs/proc/loadavg.go:33", - "pkg/sentry/fs/proc/meminfo.go:39", - "pkg/sentry/fs/proc/mounts.go:193", - "pkg/sentry/fs/proc/mounts.go:84", - "pkg/sentry/fs/proc/net.go:125", - "pkg/sentry/fs/proc/proc.go:146", - "pkg/sentry/fs/proc/proc.go:204", - "pkg/sentry/fs/proc/seqfile/seqfile.go:210", - "pkg/sentry/fs/proc/sys.go:146", - "pkg/sentry/fs/proc/sys.go:43", - "pkg/sentry/fs/proc/sys_net.go:113", - "pkg/sentry/fs/proc/sys_net.go:205", - "pkg/sentry/fs/proc/sys_net.go:233", - "pkg/sentry/fs/proc/sys_net.go:307", - "pkg/sentry/fs/proc/sys_net.go:335", - "pkg/sentry/fs/proc/sys_net.go:446", - "pkg/sentry/fs/proc/sys_net.go:456", - "pkg/sentry/fs/proc/sys_net.go:89", - "pkg/sentry/fs/proc/task.go:170", - "pkg/sentry/fs/proc/task.go:322", - "pkg/sentry/fs/proc/task.go:427", - "pkg/sentry/fs/proc/task.go:467", - "pkg/sentry/fs/proc/task.go:500", - "pkg/sentry/fs/proc/task.go:784", - "pkg/sentry/fs/proc/task.go:839", - "pkg/sentry/fs/proc/task.go:920", - "pkg/sentry/fs/proc/uid_gid_map.go:108", - "pkg/sentry/fs/proc/uid_gid_map.go:79", - "pkg/sentry/fs/proc/uptime.go:75", - "pkg/sentry/fs/ramfs/dir.go:447", - "pkg/sentry/fs/tmpfs/inode_file.go:436", - "pkg/sentry/fs/tmpfs/inode_file.go:537", - "pkg/sentry/fs/tty/dir.go:313", - "pkg/sentry/fs/tty/master.go:131", - "pkg/sentry/fs/tty/master.go:91", - "pkg/sentry/fs/tty/replica.go:116", - "pkg/sentry/fs/tty/replica.go:88", - "pkg/sentry/kernel/auth/id_map.go:269", - "pkg/sentry/kernel/fasync/fasync.go:67", - "pkg/sentry/kernel/kcov.go:209", - "pkg/sentry/kernel/kcov.go:223", - "pkg/sentry/kernel/kernel.go:343", - "pkg/sentry/kernel/kernel.go:368", - "pkg/sentry/kernel/pipe/node_test.go:112", - "pkg/sentry/kernel/pipe/node_test.go:119", - "pkg/sentry/kernel/pipe/node_test.go:130", - "pkg/sentry/kernel/pipe/node_test.go:137", - "pkg/sentry/kernel/pipe/node_test.go:149", - "pkg/sentry/kernel/pipe/node_test.go:150", - "pkg/sentry/kernel/pipe/node_test.go:158", - "pkg/sentry/kernel/pipe/node_test.go:174", - "pkg/sentry/kernel/pipe/node_test.go:180", - "pkg/sentry/kernel/pipe/node_test.go:193", - "pkg/sentry/kernel/pipe/node_test.go:202", - "pkg/sentry/kernel/pipe/node_test.go:205", - "pkg/sentry/kernel/pipe/node_test.go:216", - "pkg/sentry/kernel/pipe/node_test.go:219", - "pkg/sentry/kernel/pipe/node_test.go:271", - "pkg/sentry/kernel/pipe/node_test.go:290", - 
"pkg/sentry/kernel/pipe/pipe_test.go:93", - "pkg/sentry/kernel/pipe/reader_writer.go:65", - "pkg/sentry/kernel/posixtimer.go:157", - "pkg/sentry/kernel/ptrace.go:218", - "pkg/sentry/kernel/semaphore/semaphore.go:323", - "pkg/sentry/kernel/sessions.go:123", - "pkg/sentry/kernel/sessions.go:508", - "pkg/sentry/kernel/signal_handlers.go:57", - "pkg/sentry/kernel/task_context.go:72", - "pkg/sentry/kernel/task_exit.go:67", - "pkg/sentry/kernel/task_sched.go:255", - "pkg/sentry/kernel/task_sched.go:280", - "pkg/sentry/kernel/task_sched.go:323", - "pkg/sentry/kernel/task_stop.go:192", - "pkg/sentry/kernel/thread_group.go:530", - "pkg/sentry/kernel/timekeeper.go:316", - "pkg/sentry/kernel/vdso.go:106", - "pkg/sentry/kernel/vdso.go:118", - "pkg/sentry/memmap/memmap.go:103", - "pkg/sentry/memmap/memmap.go:163", - "pkg/sentry/mm/address_space.go:42", - "pkg/sentry/mm/address_space.go:42", - "pkg/sentry/mm/aio_context.go:208", - "pkg/sentry/mm/aio_context.go:288", - "pkg/sentry/mm/pma.go:683", - "pkg/sentry/mm/special_mappable.go:80", - "pkg/sentry/platform/systrap/subprocess.go:370", - "pkg/sentry/platform/systrap/usertrap/usertrap_amd64.go:124", - "pkg/sentry/socket/control/control.go:260", - "pkg/sentry/socket/control/control.go:94", - "pkg/sentry/socket/control/control_vfs2.go:37", - "pkg/sentry/socket/hostinet/stack.go:433", - "pkg/sentry/socket/hostinet/stack.go:438", - "pkg/sentry/socket/hostinet/stack.go:444", - "pkg/sentry/socket/hostinet/stack.go:460", - "pkg/sentry/socket/netfilter/tcp_matcher.go:74", - "pkg/sentry/socket/netfilter/udp_matcher.go:71", - "pkg/sentry/socket/netlink/route/protocol.go:38", - "pkg/sentry/socket/socket.go:332", - "pkg/sentry/socket/unix/transport/connectioned.go:394", - "pkg/sentry/socket/unix/transport/connectionless.go:152", - "pkg/sentry/socket/unix/transport/unix.go:436", - "pkg/sentry/socket/unix/transport/unix.go:490", - "pkg/sentry/socket/unix/transport/unix.go:685", - "pkg/sentry/socket/unix/transport/unix.go:795", - "pkg/sentry/syscalls/linux/sys_sem.go:62", - "pkg/sentry/syscalls/linux/sys_time.go:189", - "pkg/sentry/usage/cpu.go:42", - "pkg/sentry/vfs/anonfs.go:302", - "pkg/sentry/vfs/anonfs.go:99", - "pkg/sentry/vfs/dentry.go:214", - "pkg/sentry/vfs/epoll.go:168", - "pkg/sentry/vfs/epoll.go:314", - "pkg/sentry/vfs/file_description.go:549", - "pkg/sentry/vfs/file_description_impl_util.go:304", - "pkg/sentry/vfs/file_description_impl_util.go:412", - "pkg/sentry/vfs/filesystem.go:76", - "pkg/sentry/vfs/lock.go:15", - "pkg/sentry/vfs/lock.go:47", - "pkg/sentry/vfs/memxattr/xattr.go:37", - "pkg/sentry/vfs/mount.go:510", - "pkg/sentry/vfs/mount.go:667", - "pkg/sentry/vfs/mount_test.go:106", - "pkg/sentry/vfs/mount_test.go:160", - "pkg/sentry/vfs/mount_test.go:215", - "pkg/sentry/vfs/mount_unsafe.go:153", - "pkg/sentry/vfs/resolving_path.go:228", - "pkg/sentry/vfs/vfs.go:897", - "pkg/shim/runsc/runsc.go:16", - "pkg/shim/runsc/utils.go:16", - "pkg/shim/v1/proc/deleted_state.go:16", - "pkg/shim/v1/proc/exec.go:16", - "pkg/shim/v1/proc/exec_state.go:16", - "pkg/shim/v1/proc/init.go:16", - "pkg/shim/v1/proc/init_state.go:16", - "pkg/shim/v1/proc/io.go:16", - "pkg/shim/v1/proc/process.go:16", - "pkg/shim/v1/proc/types.go:16", - "pkg/shim/v1/proc/utils.go:16", - "pkg/shim/v1/shim/api.go:16", - "pkg/shim/v1/shim/platform.go:16", - "pkg/shim/v1/shim/service.go:16", - "pkg/shim/v1/utils/annotations.go:15", - "pkg/shim/v1/utils/utils.go:15", - "pkg/shim/v1/utils/volumes.go:15", - "pkg/shim/v2/api.go:16", - "pkg/shim/v2/epoll.go:18", - 
"pkg/shim/v2/options/options.go:15", - "pkg/shim/v2/options/options.go:24", - "pkg/shim/v2/options/options.go:26", - "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16", - "pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go", // Generated: exempt all. - "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22", - "pkg/shim/v2/service.go:15", - "pkg/shim/v2/service_linux.go:18", - "pkg/state/tests/integer_test.go:23", - "pkg/state/tests/integer_test.go:28", - "pkg/sync/rwmutex_test.go:105", - "pkg/syserr/host_linux.go:35", - "pkg/tcpip/adapters/gonet/gonet_test.go:144", - "pkg/tcpip/adapters/gonet/gonet_test.go:415", - "pkg/tcpip/adapters/gonet/gonet_test.go:99", - "pkg/tcpip/buffer/view.go:238", - "pkg/tcpip/buffer/view.go:238", - "pkg/tcpip/buffer/view.go:246", - "pkg/tcpip/header/tcp.go:151", - "pkg/tcpip/link/sharedmem/pipe/pipe_test.go:493", - "pkg/tcpip/stack/iptables.go:293", - "pkg/tcpip/stack/iptables_types.go:277", - "pkg/tcpip/stack/stack.go:553", - "pkg/tcpip/stack/transport_test.go:30", - "pkg/tcpip/transport/packet/endpoint.go:126", - "pkg/tcpip/transport/raw/endpoint.go:145", - "pkg/tcpip/transport/tcp/sack_scoreboard.go:167", - "pkg/unet/unet_test.go:634", - "pkg/unet/unet_test.go:662", - "pkg/unet/unet_test.go:703", - "pkg/unet/unet_test.go:98", - "pkg/usermem/addr.go:34", - "pkg/usermem/usermem.go:171", - "pkg/usermem/usermem.go:170", - "runsc/boot/compat.go:22", - "runsc/boot/compat.go:56", - "runsc/boot/loader.go:1115", - "runsc/boot/loader.go:1120", - "runsc/cmd/checkpoint.go:151", - "runsc/config/flags.go:32", - "runsc/container/container.go:641", - "runsc/container/container.go:988", - "runsc/specutils/specutils.go:172", - "runsc/specutils/specutils.go:428", - "runsc/specutils/specutils.go:436", - "runsc/specutils/specutils.go:442", - "runsc/specutils/specutils.go:447", - "runsc/specutils/specutils.go:454", - "test/cmd/test_app/fds.go:171", - "test/iptables/filter_output.go:251", - "test/packetimpact/testbench/connections.go:77", - "tools/bigquery/bigquery.go:106", - "tools/checkescape/test1/test1.go:108", - "tools/checkescape/test1/test1.go:122", - "tools/checkescape/test1/test1.go:137", - "tools/checkescape/test1/test1.go:151", - "tools/checkescape/test1/test1.go:170", - "tools/checkescape/test1/test1.go:39", - "tools/checkescape/test1/test1.go:45", - "tools/checkescape/test1/test1.go:50", - "tools/checkescape/test1/test1.go:64", - "tools/checkescape/test1/test1.go:80", - "tools/checkescape/test1/test1.go:94", - "tools/go_generics/imports.go:51", - "tools/go_generics/imports.go:75", - "tools/go_marshal/gomarshal/generator.go:177", - "tools/go_marshal/gomarshal/generator.go:81", - "tools/go_marshal/gomarshal/generator.go:85", - "tools/go_marshal/test/escape/escape.go:15", - "tools/go_marshal/test/test.go:164", - ), - ) + // Merge all analyzer configurations. + for name, ac := range other.Analyzers { + old, ok := c.Analyzers[name] + if !ok { + c.Analyzers[name] = ac // No analyzer in original config. + continue + } + old.merge(ac) + } +} - // Add all staticcheck analyzers; internal only. - for _, a := range staticcheck.Analyzers { - analyzerConfig[a] = staticMatcher +// Compile compiles a configuration to make it useable. +func (c *Config) Compile() error { + for i := 0; i < len(c.Groups); i++ { + if err := c.Groups[i].compile(); err != nil { + return fmt.Errorf("invalid group %q: %w", c.Groups[i].Name, err) + } } - // Add all stylecheck analyzers; internal only. 
- for _, a := range stylecheck.Analyzers { - analyzerConfig[a] = staticMatcher + if err := c.Global.compile(); err != nil { + return fmt.Errorf("invalid global: %w", err) } + for name, ac := range c.Analyzers { + if err := ac.compile(); err != nil { + return fmt.Errorf("invalid analyzer %q: %w", name, err) + } + } + return nil } -var escapesConfig = map[*analysis.Analyzer]matcher{ - // Informational only: include all packages. - checkescape.EscapeAnalyzer: alwaysMatches(), +// ShouldReport returns true iff the finding should match the Config. +func (c *Config) ShouldReport(finding Finding) bool { + fullPos := finding.Position.String() + + // Find the matching group. + var groupConfig *Group + for i := 0; i < len(c.Groups); i++ { + if c.Groups[i].regex.MatchString(fullPos) { + groupConfig = &c.Groups[i] + break + } + } + + // If there is no group matching this path, then + // we default to accept the finding. + if groupConfig == nil { + return true + } + + // Suppress via global rule? + if !c.Global.shouldReport(groupConfig, fullPos, finding.Message) { + return false + } + + // Try the analyzer config. + ac, ok := c.Analyzers[finding.Category] + if !ok { + return groupConfig.Default + } + return ac.shouldReport(groupConfig, fullPos, finding.Message) } diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 543598b52..161ea972e 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -1,6 +1,28 @@ """Nogo rules.""" -load("//tools/bazeldefs:go.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") +load("//tools/bazeldefs:go.bzl", "go_context", "go_embed_libraries", "go_importpath", "go_rule") + +NogoConfigInfo = provider( + "information about a nogo configuration", + fields = { + "srcs": "the collection of configuration files", + }, +) + +def _nogo_config_impl(ctx): + return [NogoConfigInfo( + srcs = ctx.files.srcs, + )] + +nogo_config = rule( + implementation = _nogo_config_impl, + attrs = { + "srcs": attr.label_list( + doc = "a list of yaml files (schema defined by tool/nogo/config.go).", + allow_files = True, + ), + }, +) NogoTargetInfo = provider( "information about the Go target", @@ -20,11 +42,14 @@ nogo_target = go_rule( rule, implementation = _nogo_target_impl, attrs = { - # goarch is the build architecture. This will normally be provided by a - # select statement, but this information is propagated to other rules. - "goarch": attr.string(mandatory = True), - # goos is similarly the build operating system target. 
- "goos": attr.string(mandatory = True), + "goarch": attr.string( + doc = "the Go build architecture (propagated to other rules).", + mandatory = True, + ), + "goos": attr.string( + doc = "the Go OS target (propagated to other rules).", + mandatory = True, + ), }, ) @@ -81,7 +106,7 @@ NogoStdlibInfo = provider( "information for nogo analysis (standard library facts)", fields = { "facts": "serialized standard library facts", - "findings": "package findings (if relevant)", + "raw_findings": "raw package findings (if relevant)", }, ) @@ -90,7 +115,7 @@ def _nogo_stdlib_impl(ctx): nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(ctx.label.name + ".facts") - findings = ctx.actions.declare_file(ctx.label.name + ".findings") + raw_findings = ctx.actions.declare_file(ctx.label.name + ".raw_findings") config = struct( Srcs = [f.path for f in go_ctx.stdlib_srcs], GOOS = go_ctx.goos, @@ -101,15 +126,15 @@ def _nogo_stdlib_impl(ctx): ctx.actions.write(config_file, config.to_json()) ctx.actions.run( inputs = [config_file] + go_ctx.stdlib_srcs, - outputs = [facts, findings], + outputs = [facts, raw_findings], tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), executable = ctx.files._nogo_check[0], - mnemonic = "GoStandardLibraryAnalysis", + mnemonic = "NogoStandardLibraryAnalysis", progress_message = "Analyzing Go Standard Library", arguments = go_ctx.nogo_args + [ "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-stdlib=%s" % config_file.path, - "-findings=%s" % findings.path, + "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, ], ) @@ -117,7 +142,7 @@ def _nogo_stdlib_impl(ctx): # Return the stdlib facts as output. return [NogoStdlibInfo( facts = facts, - findings = findings, + raw_findings = raw_findings, )] nogo_stdlib = go_rule( @@ -148,7 +173,8 @@ NogoInfo = provider( "information for nogo analysis", fields = { "facts": "serialized package facts", - "findings": "package findings (if relevant)", + "raw_findings": "raw package findings (if relevant)", + "escapes": "escape-only findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", "srcs": "srcs (for go_test support)", @@ -174,14 +200,12 @@ def _nogo_aspect_impl(target, ctx): # If we're using the "library" attribute, then we need to aggregate the # original library sources and dependencies into this target to perform # proper type analysis. - if ctx.rule.kind == "go_test": - library = go_test_library(ctx.rule) - if library != None: - info = library[NogoInfo] - if hasattr(info, "srcs"): - srcs = srcs + info.srcs - if hasattr(info, "deps"): - deps = deps + info.deps + for embed in go_embed_libraries(ctx.rule): + info = embed[NogoInfo] + if hasattr(info, "srcs"): + srcs = srcs + info.srcs + if hasattr(info, "deps"): + deps = deps + info.deps # Start with all target files and srcs as input. inputs = target.files.to_list() + srcs @@ -214,6 +238,7 @@ def _nogo_aspect_impl(target, ctx): # Collect all info from shadow dependencies. fact_map = dict() import_map = dict() + all_raw_findings = [] for dep in deps: # There will be no file attribute set for all transitive dependencies # that are not go_library or go_binary rules, such as a proto rules. 
@@ -231,6 +256,9 @@ def _nogo_aspect_impl(target, ctx): import_map[info.importpath] = x_files[0] fact_map[info.importpath] = info.facts.path + # Collect all findings; duplicates are resolved at the end. + all_raw_findings.extend(info.raw_findings) + # Ensure the above are available as inputs. inputs.append(info.facts) inputs += info.binaries @@ -244,7 +272,7 @@ def _nogo_aspect_impl(target, ctx): nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(target.label.name + ".facts") - findings = ctx.actions.declare_file(target.label.name + ".findings") + raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings") escapes = ctx.actions.declare_file(target.label.name + ".escapes") config = struct( ImportPath = importpath, @@ -262,39 +290,39 @@ def _nogo_aspect_impl(target, ctx): inputs.append(config_file) ctx.actions.run( inputs = inputs, - outputs = [facts, findings, escapes], + outputs = [facts, raw_findings, escapes], tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), executable = ctx.files._nogo_check[0], - mnemonic = "GoStaticAnalysis", + mnemonic = "NogoAnalysis", progress_message = "Analyzing %s" % target.label, arguments = go_ctx.nogo_args + [ "-binary=%s" % target_objfile.path, "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-package=%s" % config_file.path, - "-findings=%s" % findings.path, + "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, "-escapes=%s" % escapes.path, ], ) + # Flatten all findings from all dependencies. + # + # This is done because all the filtering must be done at the + # top-level nogo_test to dynamically apply a configuration. + # This does not actually add any additional work here, but + # will simply propagate the full list of files. + all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings] + # Return the package facts as output. - return [ - NogoInfo( - facts = facts, - findings = findings, - importpath = importpath, - binaries = binaries, - srcs = srcs, - deps = deps, - ), - OutputGroupInfo( - # Expose all findings (should just be a single file). This can be - # used for build analysis of the nogo findings. - nogo_findings = depset([findings]), - # Expose all escape analysis findings (see above). - nogo_escapes = depset([escapes]), - ), - ] + return [NogoInfo( + facts = facts, + raw_findings = all_raw_findings, + escapes = escapes, + importpath = importpath, + binaries = binaries, + srcs = srcs, + deps = deps, + )] nogo_aspect = go_rule( aspect, @@ -327,41 +355,72 @@ nogo_aspect = go_rule( def _nogo_test_impl(ctx): """Check nogo findings.""" - # Build a runner that checks the facts files. - findings = [dep[NogoInfo].findings for dep in ctx.attr.deps] - runner = ctx.actions.declare_file(ctx.label.name) + # Ensure there's a single dependency. + if len(ctx.attr.deps) != 1: + fail("nogo_test requires exactly one dep.") + raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings + escapes = ctx.attr.deps[0][NogoInfo].escapes + + # Build a step that applies the configuration. 
+ config_srcs = ctx.attr.config[NogoConfigInfo].srcs + findings = ctx.actions.declare_file(ctx.label.name + ".findings") ctx.actions.run( - inputs = findings + ctx.files.srcs, - outputs = [runner], - tools = depset(ctx.files._gentest), - executable = ctx.files._gentest[0], - mnemonic = "Gentest", + inputs = raw_findings + ctx.files.srcs + config_srcs, + outputs = [findings], + tools = depset(ctx.files._filter), + executable = ctx.files._filter[0], + mnemonic = "GoStaticAnalysis", progress_message = "Generating %s" % ctx.label, - arguments = [runner.path] + [f.path for f in findings], + arguments = ["-input=%s" % f.path for f in raw_findings] + + ["-config=%s" % f.path for f in config_srcs] + + ["-output=%s" % findings.path], ) + + # Build a runner that checks the filtered facts. + # + # Note that this calls the filter binary without any configuration, so all + # findings will be included. But this is expected, since we've already + # filtered out everything that should not be included. + runner = ctx.actions.declare_file(ctx.label.name) + runner_content = [ + "#!/bin/bash", + "exec %s -input=%s" % (ctx.files._filter[0].short_path, findings.short_path), + "", + ] + ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) + return [DefaultInfo( + # The runner just executes the filter again, on the + # newly generated filtered findings. We still need + # the filter tool as part of our runfiles, however. + runfiles = ctx.runfiles(files = ctx.files._filter + [findings]), executable = runner, + ), OutputGroupInfo( + # Propagate the filtered filters, for consumption by + # build tooling. Note that the build tooling typically + # pays attention to the mnemoic above, so this must be + # what is expected by the tooling. + nogo_findings = depset([findings]), + # Expose all escape analysis findings (see above). + nogo_escapes = depset([escapes]), )] -_nogo_test = rule( +nogo_test = rule( implementation = _nogo_test_impl, attrs = { - # deps should have only a single element. - "deps": attr.label_list(aspects = [nogo_aspect]), - # srcs exist here only to ensure that this target is - # directly affected by changes to the source files. - "srcs": attr.label_list(allow_files = True), - "_gentest": attr.label(default = "//tools/nogo:gentest"), + "config": attr.label( + mandatory = True, + doc = "A rule of kind nogo_config.", + ), + "deps": attr.label_list( + aspects = [nogo_aspect], + doc = "Exactly one Go dependency to be analyzed.", + ), + "srcs": attr.label_list( + allow_files = True, + doc = "Relevant src files. 
This is ignored except to make the nogo_test directly affected by the files.", + ), + "_filter": attr.label(default = "//tools/nogo/filter:filter"), }, test = True, ) - -def nogo_test(name, srcs, library, **kwargs): - tags = kwargs.pop("tags", []) + ["nogo"] - _nogo_test( - name = name, - srcs = srcs, - deps = [library], - tags = tags, - **kwargs - ) diff --git a/tools/nogo/filter/BUILD b/tools/nogo/filter/BUILD new file mode 100644 index 000000000..e56a783e2 --- /dev/null +++ b/tools/nogo/filter/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "filter", + srcs = ["main.go"], + nogo = False, + visibility = ["//visibility:public"], + deps = [ + "//tools/nogo", + "@in_gopkg_yaml_v2//:go_default_library", + ], +) diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go new file mode 100644 index 000000000..9cf41b3b0 --- /dev/null +++ b/tools/nogo/filter/main.go @@ -0,0 +1,131 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary check is the nogo entrypoint. +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + + yaml "gopkg.in/yaml.v2" + "gvisor.dev/gvisor/tools/nogo" +) + +type stringList []string + +func (s *stringList) String() string { + return strings.Join(*s, ",") +} + +func (s *stringList) Set(value string) error { + *s = append(*s, value) + return nil +} + +var ( + inputFiles stringList + configFiles stringList + outputFile string + showConfig bool +) + +func init() { + flag.Var(&inputFiles, "input", "findings input files") + flag.StringVar(&outputFile, "output", "", "findings output file") + flag.Var(&configFiles, "config", "findings configuration files") + flag.BoolVar(&showConfig, "show-config", false, "dump configuration only") +} + +func main() { + flag.Parse() + + // Load all available findings. + var findings []nogo.Finding + for _, filename := range inputFiles { + inputFindings, err := nogo.ExtractFindingsFromFile(filename) + if err != nil { + log.Fatalf("unable to extract findings from %s: %v", filename, err) + } + findings = append(findings, inputFindings...) + } + + // Open and merge all configuations. + config := &nogo.Config{ + Global: make(nogo.AnalyzerConfig), + Analyzers: make(map[nogo.AnalyzerName]nogo.AnalyzerConfig), + } + for _, filename := range configFiles { + content, err := ioutil.ReadFile(filename) + if err != nil { + log.Fatalf("unable to read %s: %v", filename, err) + } + var newConfig nogo.Config // For current file. 
+ if err := yaml.Unmarshal(content, &newConfig); err != nil { + log.Fatalf("unable to decode %s: %v", filename, err) + } + config.Merge(&newConfig) + if showConfig { + bytes, err := yaml.Marshal(&newConfig) + if err != nil { + log.Fatalf("error marshalling config: %v", err) + } + mergedBytes, err := yaml.Marshal(config) + if err != nil { + log.Fatalf("error marshalling config: %v", err) + } + fmt.Fprintf(os.Stdout, "Loaded configuration from %s:\n%s\n", filename, string(bytes)) + fmt.Fprintf(os.Stdout, "Merged configuration:\n%s\n", string(mergedBytes)) + } + } + if err := config.Compile(); err != nil { + log.Fatalf("error compiling config: %v", err) + } + if showConfig { + os.Exit(0) + } + + // Filter the findings (and aggregate by group). + filteredFindings := make([]nogo.Finding, 0, len(findings)) + for _, finding := range findings { + if ok := config.ShouldReport(finding); ok { + filteredFindings = append(filteredFindings, finding) + } + } + + // Write the output (if required). + // + // If the outputFile is specified, then we exit here. Otherwise, + // we continue to write to stdout and treat like a test. + if outputFile != "" { + if err := nogo.WriteFindingsToFile(filteredFindings, outputFile); err != nil { + log.Fatalf("unable to write findings: %v", err) + } + return + } + + // Treat the run as a test. + if len(filteredFindings) == 0 { + fmt.Fprintf(os.Stdout, "PASS\n") + os.Exit(0) + } + for _, finding := range filteredFindings { + fmt.Fprintf(os.Stdout, "%s\n", finding.String()) + } + os.Exit(1) +} diff --git a/tools/nogo/findings.go b/tools/nogo/findings.go new file mode 100644 index 000000000..5bd850269 --- /dev/null +++ b/tools/nogo/findings.go @@ -0,0 +1,63 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nogo + +import ( + "encoding/json" + "fmt" + "go/token" + "io/ioutil" +) + +// Finding is a single finding. +type Finding struct { + Category AnalyzerName + Position token.Position + Message string +} + +// String implements fmt.Stringer.String. +func (f *Finding) String() string { + return fmt.Sprintf("%s: %s: %s", f.Category, f.Position.String(), f.Message) +} + +// WriteFindingsToFile writes findings to a file. +func WriteFindingsToFile(findings []Finding, filename string) error { + content, err := WriteFindingsToBytes(findings) + if err != nil { + return err + } + return ioutil.WriteFile(filename, content, 0644) +} + +// WriteFindingsToBytes serializes findings as bytes. +func WriteFindingsToBytes(findings []Finding) ([]byte, error) { + return json.Marshal(findings) +} + +// ExtractFindingsFromFile loads findings from a file. +func ExtractFindingsFromFile(filename string) ([]Finding, error) { + content, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + return ExtractFindingsFromBytes(content) +} + +// ExtractFindingsFromBytes loads findings from bytes. 
+func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { + err = json.Unmarshal(content, &findings) + return findings, err +} diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh deleted file mode 100755 index 0a762f9f6..000000000 --- a/tools/nogo/gentest.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -# Copyright 2019 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -euo pipefail - -if [[ "$#" -lt 2 ]]; then - echo "usage: $0 <output> <findings...>" - exit 2 -fi -declare violations=0 -declare output=$1 -shift - -# Start the script. -echo "#!/bin/sh" > "${output}" - -# Read a list of findings files. -declare filename -declare line -for filename in "$@"; do - if [[ -z "${filename}" ]]; then - continue - fi - while read -r line; do - line="${line@Q}" - violations=$((${violations}+1)); - echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}" - done < "${filename}" -done - -# Show violations. -if [[ "${violations}" -eq 0 ]]; then - echo "echo -e '\\033[0;32mPASS\\033[0;31m\\033[0m'" >> "${output}" -else - echo "exit 1" >> "${output}" -fi diff --git a/tools/nogo/matchers.go b/tools/nogo/matchers.go deleted file mode 100644 index b7b73fa27..000000000 --- a/tools/nogo/matchers.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nogo - -import ( - "go/token" - "regexp" - "strings" - - "golang.org/x/tools/go/analysis" -) - -type matcher interface { - ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool -} - -// pathRegexps filters explicit paths. -type pathRegexps struct { - expr []*regexp.Regexp - - // include, if true, indicates that paths matching any regexp in expr - // match. - // - // If false, paths matching no regexps in expr match. - include bool -} - -// buildRegexps builds a list of regular expressions. -// -// This will panic on error. -func buildRegexps(prefix string, args ...string) []*regexp.Regexp { - result := make([]*regexp.Regexp, 0, len(args)) - for _, arg := range args { - result = append(result, regexp.MustCompile(prefix+arg)) - } - return result -} - -// notPath works around the lack of backtracking. -// -// It is used to construct a regular expression for non-matching components. 
-func notPath(name string) string { - sb := strings.Builder{} - sb.WriteString("(") - for i := range name { - if i > 0 { - sb.WriteString("|") - } - sb.WriteString(name[:i]) - sb.WriteString("[^") - sb.WriteByte(name[i]) - sb.WriteString("/][^/]*") - } - sb.WriteString(")") - return sb.String() -} - -// ShouldReport implements matcher.ShouldReport. -func (p *pathRegexps) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { - fullPos := fs.Position(d.Pos).String() - for _, path := range p.expr { - if path.MatchString(fullPos) { - return p.include - } - } - return !p.include -} - -// internalExcluded excludes specific internal paths. -func internalExcluded(paths ...string) *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(internalPrefix, paths...), - include: false, - } -} - -// excludedExcluded excludes specific external paths. -func externalExcluded(paths ...string) *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(externalPrefix, paths...), - include: false, - } -} - -// internalMatches returns a path matcher for internal packages. -func internalMatches() *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(internalPrefix, internalDefault), - include: true, - } -} - -// generatedExcluded excludes all generated code. -func generatedExcluded() *pathRegexps { - return &pathRegexps{ - expr: buildRegexps(generatedPrefix, ".*"), - include: false, - } -} - -// resultExcluded excludes explicit message contents. -type resultExcluded []string - -// ShouldReport implements matcher.ShouldReport. -func (r resultExcluded) ShouldReport(d analysis.Diagnostic, _ *token.FileSet) bool { - for _, str := range r { - if strings.Contains(d.Message, str) { - return false - } - } - return true // Not excluded. -} - -// andMatcher is a composite matcher. -type andMatcher struct { - all []matcher -} - -// ShouldReport implements matcher.ShouldReport. -func (a *andMatcher) ShouldReport(d analysis.Diagnostic, fs *token.FileSet) bool { - for _, m := range a.all { - if !m.ShouldReport(d, fs) { - return false - } - } - return true -} - -// and is a syntactic convension for andMatcher. -func and(ms ...matcher) *andMatcher { - return &andMatcher{ - all: ms, - } -} - -// anyMatcher matches everything. -type anyMatcher struct{} - -// ShouldReport implements matcher.ShouldReport. -func (anyMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { - return true -} - -// alwaysMatches returns an anyMatcher instance. -func alwaysMatches() anyMatcher { - return anyMatcher{} -} - -// neverMatcher will never match. -type neverMatcher struct{} - -// ShouldReport implements matcher.ShouldReport. -func (neverMatcher) ShouldReport(analysis.Diagnostic, *token.FileSet) bool { - return false -} - -// disableMatches returns a neverMatcher instance. -func disableMatches() neverMatcher { - return neverMatcher{} -} diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index e19e3c237..779d4d6d8 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -21,7 +21,6 @@ package nogo import ( "encoding/json" "errors" - "flag" "fmt" "go/ast" "go/build" @@ -45,20 +44,20 @@ import ( "gvisor.dev/gvisor/tools/checkescape" ) -// stdlibConfig is serialized as the configuration. +// StdlibConfig is serialized as the configuration. // // This contains everything required for stdlib analysis. -type stdlibConfig struct { +type StdlibConfig struct { Srcs []string GOOS string GOARCH string Tags []string } -// packageConfig is serialized as the configuration. +// PackageConfig is serialized as the configuration. 
// // This contains everything required for single package analysis. -type packageConfig struct { +type PackageConfig struct { ImportPath string GoFiles []string NonGoFiles []string @@ -84,7 +83,7 @@ type saver func([]byte) error // // This is done because all stdlib data is stored together, and we don't want // to load this data many times over. -func (c *packageConfig) factLoader() (loader, error) { +func (c *PackageConfig) factLoader() (loader, error) { allFacts := make(map[string][]byte) if c.StdlibFacts != "" { data, err := ioutil.ReadFile(c.StdlibFacts) @@ -114,7 +113,7 @@ func (c *packageConfig) factLoader() (loader, error) { // shouldInclude indicates whether the file should be included. // // NOTE: This does only basic parsing of tags. -func (c *packageConfig) shouldInclude(path string) (bool, error) { +func (c *PackageConfig) shouldInclude(path string) (bool, error) { ctx := build.Default ctx.GOOS = c.GOOS ctx.GOARCH = c.GOARCH @@ -128,7 +127,7 @@ func (c *packageConfig) shouldInclude(path string) (bool, error) { // files, and the facts. Note that this importer implementation will always // pass when a given package is not available. type importer struct { - *packageConfig + *PackageConfig fset *token.FileSet cache map[string]*types.Package lastErr error @@ -185,14 +184,14 @@ func (i *importer) Import(path string) (*types.Package, error) { // ErrSkip indicates the package should be skipped. var ErrSkip = errors.New("skipped") -// checkStdlib checks the standard library. +// CheckStdlib checks the standard library. // // This constructs a synthetic package configuration for each library in the -// standard library sources, and call checkPackage repeatedly. +// standard library sources, and call CheckPackage repeatedly. // // Note that not all parts of the source are expected to build. We skip obvious // test files, and cmd files, which should not be dependencies. -func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) { +func CheckStdlib(config *StdlibConfig, analyzers []*analysis.Analyzer) (allFindings []Finding, facts []byte, err error) { if len(config.Srcs) == 0 { return nil, nil, nil } @@ -225,7 +224,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } // Aggregate all files by directory. - packages := make(map[string]*packageConfig) + packages := make(map[string]*PackageConfig) for _, file := range config.Srcs { if !strings.HasPrefix(file, rootSrcPrefix) { // Superflouous file. @@ -243,7 +242,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } c, ok := packages[pkg] if !ok { - c = &packageConfig{ + c = &PackageConfig{ ImportPath: pkg, GOOS: config.GOOS, GOARCH: config.GOARCH, @@ -262,7 +261,6 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str } // Closure to check a single package. - allFindings := make([]string, 0) stdlibFacts := make(map[string][]byte) stdlibErrs := make(map[string]error) var checkOne func(pkg string) error // Recursive. @@ -301,7 +299,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str }() // Run the analysis. - findings, factData, err := checkPackage(config, ac, checkOne) + findings, factData, err := CheckPackage(config, analyzers, checkOne) if err != nil { // If we can't analyze a package from the standard library, // then we skip it. It will simply not have any findings. 
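With CheckPackage and CheckStdlib exported, and findings returned as structured Finding values rather than pre-filtered strings, analysis can now be driven outside the old flag-based Main. The sketch below shows roughly what such a driver could look like; the "package.json" path, the single printf analyzer, and the output file names are illustrative assumptions and are not part of this change.

```go
// Hypothetical driver sketch for the newly exported nogo API.
package main

import (
    "encoding/json"
    "io/ioutil"
    "log"
    "os"

    "golang.org/x/tools/go/analysis"
    "golang.org/x/tools/go/analysis/passes/printf"
    "gvisor.dev/gvisor/tools/nogo"
)

func main() {
    // Decode a package configuration (formerly the unexported packageConfig).
    f, err := os.Open("package.json") // assumed path, produced by the build rules.
    if err != nil {
        log.Fatalf("open config: %v", err)
    }
    defer f.Close()
    var config nogo.PackageConfig
    if err := json.NewDecoder(f).Decode(&config); err != nil {
        log.Fatalf("decode config: %v", err)
    }

    // Run the analyzers. Findings are no longer filtered here; they are
    // serialized and filtered later against nogo.yaml by //tools/nogo/filter.
    findings, factData, err := nogo.CheckPackage(&config, []*analysis.Analyzer{printf.Analyzer}, nil)
    if err != nil {
        log.Fatalf("check package: %v", err)
    }
    if err := nogo.WriteFindingsToFile(findings, "findings.json"); err != nil {
        log.Fatalf("write findings: %v", err)
    }
    if err := ioutil.WriteFile("facts.bin", factData, 0644); err != nil {
        log.Fatalf("write facts: %v", err)
    }
}
```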
@@ -344,7 +342,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str return allFindings, factData, nil } -// checkPackage runs all analyzers. +// CheckPackage runs all given analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. // This returns a list of matching analysis issues, or an error if the analysis @@ -352,9 +350,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str // // [1] bazelbuid/rules_go/tools/builders/nogo_main.go // [2] golang.org/x/tools/go/checker/internal/checker -func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) { +func CheckPackage(config *PackageConfig, analyzers []*analysis.Analyzer, importCallback func(string) error) (findings []Finding, factData []byte, err error) { imp := &importer{ - packageConfig: config, + PackageConfig: config, fset: token.NewFileSet(), cache: make(map[string]*types.Package), callback: importCallback, @@ -406,7 +404,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo // Register fact types and establish dependencies between analyzers. // The visit closure will execute recursively, and populate results // will all required analysis results. - diagnostics := make(map[*analysis.Analyzer][]analysis.Diagnostic) results := make(map[*analysis.Analyzer]interface{}) var visit func(*analysis.Analyzer) error // For recursion. visit = func(a *analysis.Analyzer) error { @@ -421,27 +418,25 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } } - // Prepare the matcher. - m := ac[a] - report := func(d analysis.Diagnostic) { - if m.ShouldReport(d, imp.fset) { - diagnostics[a] = append(diagnostics[a], d) - } - } - // Run the analysis. factFilter := make(map[reflect.Type]bool) for _, f := range a.FactTypes { factFilter[reflect.TypeOf(f)] = true } p := &analysis.Pass{ - Analyzer: a, - Fset: imp.fset, - Files: syntax, - Pkg: types, - TypesInfo: typesInfo, - ResultOf: results, // All results. - Report: report, + Analyzer: a, + Fset: imp.fset, + Files: syntax, + Pkg: types, + TypesInfo: typesInfo, + ResultOf: results, // All results. + Report: func(d analysis.Diagnostic) { + findings = append(findings, Finding{ + Category: AnalyzerName(a.Name), + Position: imp.fset.Position(d.Pos), + Message: d.Message, + }) + }, ImportPackageFact: facts.ImportPackageFact, ExportPackageFact: facts.ExportPackageFact, ImportObjectFact: facts.ImportObjectFact, @@ -464,7 +459,7 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } // Visit all analyzers recursively. - for a, _ := range ac { + for _, a := range analyzers { if imp.lastErr == ErrSkip { continue // No local analysis. } @@ -473,114 +468,6 @@ func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, impo } } - // Convert all diagnostics to strings. - findings := make([]string, 0, len(diagnostics)) - for a, ds := range diagnostics { - for _, d := range ds { - // Include the anlyzer name for debugability and configuration. - findings = append(findings, fmt.Sprintf("%s: %s: %s", a.Name, imp.fset.Position(d.Pos), d.Message)) - } - } - // Return all findings. 
- factData := facts.Encode() - return findings, factData, nil -} - -var ( - packageFile = flag.String("package", "", "package configuration file (in JSON format)") - stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") - findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") - factsOutput = flag.String("facts", "", "output file for facts (optional)") - escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") -) - -func loadConfig(file string, config interface{}) interface{} { - // Load the configuration. - f, err := os.Open(file) - if err != nil { - log.Fatalf("unable to open configuration %q: %v", file, err) - } - defer f.Close() - dec := json.NewDecoder(f) - dec.DisallowUnknownFields() - if err := dec.Decode(config); err != nil { - log.Fatalf("unable to decode configuration: %v", err) - } - return config -} - -// Main is the entrypoint; it should be called directly from main. -// -// N.B. This package registers it's own flags. -func Main() { - // Parse all flags. - flag.Parse() - - var ( - findings []string - factData []byte - err error - ) - - // Check the configuration. - if *packageFile != "" && *stdlibFile != "" { - log.Fatalf("unable to perform stdlib and package analysis; provide only one!") - } else if *stdlibFile != "" { - // Perform basic analysis. - c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig) - findings, factData, err = checkStdlib(c, analyzerConfig) - } else if *packageFile != "" { - // Perform basic analysis. - c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig) - findings, factData, err = checkPackage(c, analyzerConfig, nil) - // Do we need to do escape analysis? - if *escapesOutput != "" { - f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - log.Fatalf("unable to open output %q: %v", *escapesOutput, err) - } - defer f.Close() - escapes, _, err := checkPackage(c, escapesConfig, nil) - if err != nil { - log.Fatalf("error performing escape analysis: %v", err) - } - for _, escape := range escapes { - fmt.Fprintf(f, "%s\n", escape) - } - } - } else { - log.Fatalf("please provide at least one of package or stdlib!") - } - - // Save facts. - if *factsOutput != "" { - if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { - log.Fatalf("error saving findings to %q: %v", *factsOutput, err) - } - } - - // Open the output file. - var w io.Writer = os.Stdout - if *findingsOutput != "" { - f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - log.Fatalf("unable to open output %q: %v", *findingsOutput, err) - } - defer f.Close() - w = f - } - - // Handle findings & errors. - if err != nil { - log.Fatalf("error checking package: %v", err) - } - if len(findings) == 0 { - return - } - - // Print findings. - for _, finding := range findings { - fmt.Fprintf(w, "%s\n", finding) - } + return findings, facts.Encode(), nil } diff --git a/tools/nogo/register.go b/tools/nogo/register.go deleted file mode 100644 index 34b173937..000000000 --- a/tools/nogo/register.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package nogo - -import ( - "encoding/gob" - "log" - - "golang.org/x/tools/go/analysis" -) - -// analyzers returns all configured analyzers. -func analyzers() (all []*analysis.Analyzer) { - for a, _ := range analyzerConfig { - all = append(all, a) - } - for a, _ := range escapesConfig { - all = append(all, a) - } - return all -} - -func init() { - // Validate basic configuration. - if err := analysis.Validate(analyzers()); err != nil { - log.Fatalf("unable to validate analyzer: %v", err) - } - - // Register all fact types. - // - // N.B. This needs to be done recursively, because there may be - // analyzers in the Requires list that do not appear explicitly above. - registered := make(map[*analysis.Analyzer]struct{}) - var register func(*analysis.Analyzer) - register = func(a *analysis.Analyzer) { - if _, ok := registered[a]; ok { - return - } - - // Regsiter dependencies. - for _, da := range a.Requires { - register(da) - } - - // Register local facts. - for _, f := range a.FactTypes { - gob.Register(f) - } - - registered[a] = struct{}{} // Done. - } - for _, a := range analyzers() { - register(a) - } -} diff --git a/tools/nogo/util/BUILD b/tools/nogo/util/BUILD deleted file mode 100644 index 7ab340b51..000000000 --- a/tools/nogo/util/BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "util", - srcs = ["util.go"], - visibility = ["//visibility:public"], -) diff --git a/tools/nogo/util/util.go b/tools/nogo/util/util.go deleted file mode 100644 index 919fec799..000000000 --- a/tools/nogo/util/util.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package util contains nogo-related utilities. -package util - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" -) - -// findingRegexp is used to parse findings. -var findingRegexp = regexp.MustCompile(`([a-zA-Z0-9_\/\.-]+): (-|([a-zA-Z0-9_\/\.-]+):([0-9]+)(:([0-9]+))?): (.*)`) - -const ( - categoryIndex = 1 - fullPathAndLineIndex = 2 - fullPathIndex = 3 - lineIndex = 4 - messageIndex = 7 -) - -// Finding is a single finding. -type Finding struct { - Category string - Path string - Line int - Message string -} - -// ExtractFindingsFromFile loads findings from a file. -func ExtractFindingsFromFile(filename string) ([]Finding, error) { - content, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err - } - return ExtractFindingsFromBytes(content) -} - -// ExtractFindingsFromBytes loads findings from bytes. 
-func ExtractFindingsFromBytes(content []byte) (findings []Finding, err error) { - lines := strings.Split(string(content), "\n") - for _, singleLine := range lines { - // Skip blank lines. - singleLine = strings.TrimSpace(singleLine) - if singleLine == "" { - continue - } - m := findingRegexp.FindStringSubmatch(singleLine) - if m == nil { - // We shouldn't see findings like this. - return findings, fmt.Errorf("poorly formated line: %v", singleLine) - } - if m[fullPathAndLineIndex] == "-" { - continue // No source file available. - } - // Cleanup the message. - message := m[messageIndex] - message = strings.Replace(message, " → ", "\n → ", -1) - message = strings.Replace(message, " or ", "\n or ", -1) - // Construct a new annotation. - lineNumber, _ := strconv.ParseUint(m[lineIndex], 10, 32) - findings = append(findings, Finding{ - Category: m[categoryIndex], - Path: m[fullPathIndex], - Line: int(lineNumber), - Message: message, - }) - } - return findings, nil -} diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD index 7d9c9a3fb..6932bba9a 100644 --- a/tools/parsers/BUILD +++ b/tools/parsers/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "go_test") +load("//tools:defs.bzl", "go_binary", "go_library", "go_test") package(licenses = ["notice"]) @@ -7,6 +7,7 @@ go_test( size = "small", srcs = ["go_parser_test.go"], library = ":parsers", + nogo = False, deps = [ "//tools/bigquery", "@com_github_google_go_cmp//cmp:go_default_library", @@ -19,9 +20,26 @@ go_library( srcs = [ "go_parser.go", ], + nogo = False, visibility = ["//:sandbox"], deps = [ "//test/benchmarks/tools", "//tools/bigquery", ], ) + +go_binary( + name = "parser", + testonly = 1, + srcs = [ + "parser_main.go", + "version.go", + ], + nogo = False, + x_defs = {"main.version": "{STABLE_VERSION}"}, + deps = [ + ":parsers", + "//runsc/flag", + "//tools/bigquery", + ], +) diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go index 2cf74c883..57e538149 100644 --- a/tools/parsers/go_parser.go +++ b/tools/parsers/go_parser.go @@ -27,20 +27,21 @@ import ( "gvisor.dev/gvisor/tools/bigquery" ) -// parseOutput expects golang benchmark output returns a Benchmark struct formatted for BigQuery. -func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]*bigquery.Benchmark, error) { - var benchmarks []*bigquery.Benchmark +// ParseOutput expects golang benchmark output and returns a struct formatted +// for BigQuery. +func ParseOutput(output string, name string, official bool) (*bigquery.Suite, error) { + suite := bigquery.NewSuite(name, official) lines := strings.Split(output, "\n") for _, line := range lines { - bm, err := parseLine(line, metadata, official) + bm, err := parseLine(line) if err != nil { return nil, fmt.Errorf("failed to parse line '%s': %v", line, err) } if bm != nil { - benchmarks = append(benchmarks, bm) + suite.Benchmarks = append(suite.Benchmarks, bm) } } - return benchmarks, nil + return suite, nil } // parseLine handles parsing a benchmark line into a bigquery.Benchmark. @@ -58,9 +59,8 @@ func parseOutput(output string, metadata *bigquery.Metadata, official bool) ([]* // {Name: ns/op, Unit: ns/op, Sample: 1397875880} // {Name: requests_per_second, Unit: QPS, Sample: 140 } // } -// Metadata: metadata //} -func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigquery.Benchmark, error) { +func parseLine(line string) (*bigquery.Benchmark, error) { fields := strings.Fields(line) // Check if this line is a Benchmark line. Otherwise ignore the line. 
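Because ParseOutput now returns a single bigquery.Suite, callers can parse raw `go test -bench` style output and walk the grouped benchmarks directly. A minimal sketch follows, reusing the sample line from the parseLine documentation above; the suite name is made up.

```go
package main

import (
    "fmt"
    "log"

    "gvisor.dev/gvisor/tools/parsers"
)

func main() {
    // Sample benchmark line, as in the parseLine documentation above.
    output := "BenchmarkRuby/server_threads.1-6 1 1397875880 ns/op 140 requests_per_second"

    // ParseOutput groups all parsed benchmarks under one bigquery.Suite
    // instead of returning a bare slice of benchmarks.
    suite, err := parsers.ParseOutput(output, "example-suite", false /* official */)
    if err != nil {
        log.Fatalf("parse: %v", err)
    }
    for _, bm := range suite.Benchmarks {
        fmt.Printf("benchmark with %d conditions and %d metrics\n",
            len(bm.Condition), len(bm.Metric))
    }
}
```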
@@ -78,8 +78,7 @@ func parseLine(line string, metadata *bigquery.Metadata, official bool) (*bigque return nil, fmt.Errorf("parse name/params: %v", err) } - bm := bigquery.NewBenchmark(name, iters, official) - bm.Metadata = metadata + bm := bigquery.NewBenchmark(name, iters) for _, p := range params { bm.AddCondition(p.Name, p.Value) } diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go index 36996b7c8..f0737d46b 100644 --- a/tools/parsers/go_parser_test.go +++ b/tools/parsers/go_parser_test.go @@ -94,13 +94,11 @@ func TestParseLine(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := parseLine(tc.data, nil, false) + got, err := parseLine(tc.data) if err != nil { t.Fatalf("parseLine failed with: %v", err) } - tc.want.Timestamp = got.Timestamp - if !cmp.Equal(tc.want, got, nil) { for _, c := range got.Condition { t.Logf("Cond: %+v", c) @@ -150,14 +148,14 @@ BenchmarkRuby/server_threads.5-6 1 1416003331 ns/op 0.00950 average_latency.s 46 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - bms, err := parseOutput(tc.data, nil, false) + suite, err := ParseOutput(tc.data, "", false) if err != nil { t.Fatalf("parseOutput failed: %v", err) - } else if len(bms) != tc.numBenchmarks { - t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(bms), bms) + } else if len(suite.Benchmarks) != tc.numBenchmarks { + t.Fatalf("NumBenchmarks failed want: %d got: %d %+v", tc.numBenchmarks, len(suite.Benchmarks), suite.Benchmarks) } - for _, bm := range bms { + for _, bm := range suite.Benchmarks { if len(bm.Metric) != tc.numMetrics { t.Fatalf("NumMetrics failed want: %d got: %d %+v", tc.numMetrics, len(bm.Metric), bm.Metric) } diff --git a/tools/parsers/parser_main.go b/tools/parsers/parser_main.go new file mode 100644 index 000000000..7cce69e03 --- /dev/null +++ b/tools/parsers/parser_main.go @@ -0,0 +1,135 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary parser parses Benchmark data from golang benchmarks, +// puts it into a Schema for BigQuery, and sends it to BigQuery. +// parser will also initialize a table with the Benchmarks BigQuery schema. +package main + +import ( + "context" + "fmt" + "io/ioutil" + "os" + + "gvisor.dev/gvisor/runsc/flag" + bq "gvisor.dev/gvisor/tools/bigquery" + "gvisor.dev/gvisor/tools/parsers" +) + +const ( + initString = "init" + initDescription = "initializes a new table with benchmarks schema" + parseString = "parse" + parseDescription = "parses given benchmarks file and sends it to BigQuery table." +) + +var ( + // The init command will create a new dataset/table in the given project and initialize + // the table with the schema in //tools/bigquery/bigquery.go. If the table/dataset exists + // or has been initialized, init has no effect and successfully returns. 
+ initCmd = flag.NewFlagSet(initString, flag.ContinueOnError) + initProject = initCmd.String("project", "", "GCP project to send benchmarks.") + initDataset = initCmd.String("dataset", "", "dataset to send benchmarks data.") + initTable = initCmd.String("table", "", "table to send benchmarks data.") + + // The parse command parses benchmark data in `file` and sends it to the + // requested table. + parseCmd = flag.NewFlagSet(parseString, flag.ContinueOnError) + file = parseCmd.String("file", "", "file to parse for benchmarks") + name = parseCmd.String("suite_name", "", "name of the benchmark suite") + parseProject = parseCmd.String("project", "", "GCP project to send benchmarks.") + parseDataset = parseCmd.String("dataset", "", "dataset to send benchmarks data.") + parseTable = parseCmd.String("table", "", "table to send benchmarks data.") + official = parseCmd.Bool("official", false, "mark input data as official.") + runtime = parseCmd.String("runtime", "", "runtime used to run the benchmark") +) + +// initBenchmarks initializes a dataset/table in a BigQuery project. +func initBenchmarks(ctx context.Context) error { + return bq.InitBigQuery(ctx, *initProject, *initDataset, *initTable, nil) +} + +// parseBenchmarks parses the given file into the BigQuery schema, +// adds some custom data for the commit, and sends the data to BigQuery. +func parseBenchmarks(ctx context.Context) error { + data, err := ioutil.ReadFile(*file) + if err != nil { + return fmt.Errorf("failed to read file %s: %v", *file, err) + } + suite, err := parsers.ParseOutput(string(data), *name, *official) + if err != nil { + return fmt.Errorf("failed parse data: %v", err) + } + if len(suite.Benchmarks) < 1 { + fmt.Fprintf(os.Stderr, "Failed to find benchmarks for file: %s", *file) + return nil + } + + extraConditions := []*bq.Condition{ + { + Name: "runtime", + Value: *runtime, + }, + { + Name: "version", + Value: version, + }, + } + + suite.Official = *official + suite.Conditions = append(suite.Conditions, extraConditions...) + return bq.SendBenchmarks(ctx, suite, *parseProject, *parseDataset, *parseTable, nil) +} + +func main() { + ctx := context.Background() + switch { + // the "init" command + case len(os.Args) >= 2 && os.Args[1] == initString: + if err := initCmd.Parse(os.Args[2:]); err != nil { + fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err) + os.Exit(1) + } + if err := initBenchmarks(ctx); err != nil { + failure := "failed to initialize project: %s dataset: %s table: %s: %v\n" + fmt.Fprintf(os.Stderr, failure, *parseProject, *parseDataset, *parseTable, err) + os.Exit(1) + } + // the "parse" command. + case len(os.Args) >= 2 && os.Args[1] == parseString: + if err := parseCmd.Parse(os.Args[2:]); err != nil { + fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err) + os.Exit(1) + } + if err := parseBenchmarks(ctx); err != nil { + fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v\n", err) + os.Exit(1) + } + default: + printUsage() + os.Exit(1) + } +} + +// printUsage prints the top level usage string. +func printUsage() { + usage := `Usage: parser <command> <flags> ... + +Available commands: + %s %s + %s %s +` + fmt.Fprintf(os.Stderr, usage, initCmd.Name(), initDescription, parseCmd.Name(), parseDescription) +} diff --git a/tools/parsers/version.go b/tools/parsers/version.go new file mode 100644 index 000000000..ab9194b9d --- /dev/null +++ b/tools/parsers/version.go @@ -0,0 +1,18 @@ +// Copyright 2019 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +// version is set during linking. +var version = "VERSION_MISSING" diff --git a/tools/tag_release.sh b/tools/tag_release.sh index b0bab74b4..50378065e 100755 --- a/tools/tag_release.sh +++ b/tools/tag_release.sh @@ -43,7 +43,7 @@ fi closest_commit() { while read line; do - if [[ "$line" =~ "commit " ]]; then + if [[ "$line" =~ ^"commit " ]]; then current_commit="${line#commit }" continue elif [[ "$line" =~ "PiperOrigin-RevId: " ]]; then @@ -57,7 +57,9 @@ closest_commit() { # Is the passed identifier a sha commit? if ! git show "${target_commit}" &> /dev/null; then # Extract the commit given a piper ID. - declare -r commit="$(git log | closest_commit "${target_commit}")" + commit="$(set +o pipefail; \ + git log --first-parent | closest_commit "${target_commit}")" + declare -r commit else declare -r commit="${target_commit}" fi diff --git a/webhook/BUILD b/webhook/BUILD new file mode 100644 index 000000000..33c585504 --- /dev/null +++ b/webhook/BUILD @@ -0,0 +1,28 @@ +load("//images:defs.bzl", "docker_image") +load("//tools:defs.bzl", "go_binary", "pkg_tar") + +package(licenses = ["notice"]) + +docker_image( + name = "webhook_image", + data = ":files", + statements = ['ENTRYPOINT ["/webhook"]'], +) + +# files is the full file system of the webhook container. It is simply: +# / +# └─ webhook +pkg_tar( + name = "files", + srcs = [":webhook"], + extension = "tgz", + strip_prefix = "/third_party/gvisor/webhook", +) + +go_binary( + name = "webhook", + srcs = ["main.go"], + pure = "on", + static = "on", + deps = ["//webhook/pkg/cli"], +) diff --git a/webhook/main.go b/webhook/main.go new file mode 100644 index 000000000..220016543 --- /dev/null +++ b/webhook/main.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary main serves a mutating Kubernetes webhook. 
+package main + +import ( + "gvisor.dev/gvisor/webhook/pkg/cli" +) + +func main() { + cli.Main() +} diff --git a/webhook/pkg/cli/BUILD b/webhook/pkg/cli/BUILD new file mode 100644 index 000000000..ac093c556 --- /dev/null +++ b/webhook/pkg/cli/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "cli", + srcs = ["cli.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//webhook/pkg/injector", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_apimachinery//pkg/util/net:go_default_library", + "@io_k8s_client_go//kubernetes:go_default_library", + "@io_k8s_client_go//rest:go_default_library", + ], +) diff --git a/webhook/pkg/cli/cli.go b/webhook/pkg/cli/cli.go new file mode 100644 index 000000000..a07d341a2 --- /dev/null +++ b/webhook/pkg/cli/cli.go @@ -0,0 +1,115 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli provides a CLI interface for a mutating Kubernetes webhook. +package cli + +import ( + "flag" + "fmt" + "net" + "net/http" + "os" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/webhook/pkg/injector" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8snet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +var ( + address = flag.String("address", "", "The ip address the admission webhook serves on. If unspecified, a public address is selected automatically.") + port = flag.Int("port", 0, "The port the admission webhook serves on.") + podLabels = flag.String("pod-namespace-labels", "", "A comma-separated namespace label selector, the admission webhook will only take effect on pods in selected namespaces, e.g. `label1,label2`.") +) + +// Main runs the webhook. +func Main() { + flag.Parse() + + if err := run(); err != nil { + log.Warningf("%v", err) + os.Exit(1) + } +} + +func run() error { + log.Infof("Starting %s\n", injector.Name) + + // Create client config. + cfg, err := rest.InClusterConfig() + if err != nil { + return fmt.Errorf("create in cluster config: %w", err) + } + + // Create clientset. 
+ clientset, err := kubernetes.NewForConfig(cfg) + if err != nil { + return fmt.Errorf("create kubernetes client: %w", err) + } + + if err := injector.CreateConfiguration(clientset, parsePodLabels()); err != nil { + return fmt.Errorf("create webhook configuration: %w", err) + } + + if err := startWebhookHTTPS(clientset); err != nil { + return fmt.Errorf("start webhook https server: %w", err) + } + + return nil +} + +func parsePodLabels() *metav1.LabelSelector { + rv := &metav1.LabelSelector{} + for _, s := range strings.Split(*podLabels, ",") { + req := metav1.LabelSelectorRequirement{ + Key: strings.TrimSpace(s), + Operator: "Exists", + } + rv.MatchExpressions = append(rv.MatchExpressions, req) + } + return rv +} + +func startWebhookHTTPS(clientset kubernetes.Interface) error { + log.Infof("Starting HTTPS handler") + defer log.Infof("Stopping HTTPS handler") + + if *address == "" { + ip, err := k8snet.ChooseHostInterface() + if err != nil { + return fmt.Errorf("select ip address: %w", err) + } + *address = ip.String() + } + mux := http.NewServeMux() + mux.Handle("/", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + injector.Admit(w, r) + })) + server := &http.Server{ + // Listen on all addresses. + Addr: net.JoinHostPort(*address, strconv.Itoa(*port)), + TLSConfig: injector.GetTLSConfig(), + Handler: mux, + } + if err := server.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + return fmt.Errorf("start HTTPS handler: %w", err) + } + return nil +} diff --git a/webhook/pkg/injector/BUILD b/webhook/pkg/injector/BUILD new file mode 100644 index 000000000..d296981be --- /dev/null +++ b/webhook/pkg/injector/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "injector", + srcs = [ + "certs.go", + "webhook.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "@com_github_mattbaird_jsonpatch//:go_default_library", + "@io_k8s_api//admission/v1beta1:go_default_library", + "@io_k8s_api//admissionregistration/v1beta1:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_apimachinery//pkg/api/errors:go_default_library", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_client_go//kubernetes:go_default_library", + ], +) + +genrule( + name = "certs", + srcs = [":gencerts"], + outs = ["certs.go"], + cmd = "$$(cut -d ' ' -f 1 <<< \"$(locations :gencerts)\") $@", +) + +sh_binary( + name = "gencerts", + srcs = ["gencerts.sh"], +) diff --git a/webhook/pkg/injector/gencerts.sh b/webhook/pkg/injector/gencerts.sh new file mode 100755 index 000000000..f7fda4b63 --- /dev/null +++ b/webhook/pkg/injector/gencerts.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Generates the a CA cert, a server key, and a server cert signed by the CA. 
+# reference: +# https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/testcerts/gencerts.sh +set -euo pipefail + +# Do all the work in TMPDIR, then copy out generated code and delete TMPDIR. +declare -r OUTDIR="$(readlink -e .)" +declare -r TMPDIR="$(mktemp -d)" +cd "${TMPDIR}" +function cleanup() { + cd "${OUTDIR}" + rm -rf "${TMPDIR}" +} +trap cleanup EXIT + +declare -r CN_BASE="e2e" +declare -r CN="gvisor-injection-admission-webhook.e2e.svc" + +cat > server.conf << EOF +[req] +req_extensions = v3_req +distinguished_name = req_distinguished_name +[req_distinguished_name] +[ v3_req ] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = clientAuth, serverAuth +EOF + +declare -r OUTFILE="${TMPDIR}/certs.go" + +# We depend on OpenSSL being present. + +# Create a certificate authority. +openssl genrsa -out caKey.pem 2048 +openssl req -x509 -new -nodes -key caKey.pem -days 100000 -out caCert.pem -subj "/CN=${CN_BASE}_ca" -config server.conf + +# Create a server certificate. +openssl genrsa -out serverKey.pem 2048 +# Note the CN is the DNS name of the service of the webhook. +openssl req -new -key serverKey.pem -out server.csr -subj "/CN=${CN}" -config server.conf +openssl x509 -req -in server.csr -CA caCert.pem -CAkey caKey.pem -CAcreateserial -out serverCert.pem -days 100000 -extensions v3_req -extfile server.conf + +echo "package injector" > "${OUTFILE}" +echo "" >> "${OUTFILE}" +echo "// This file was generated using openssl by the gencerts.sh script." >> "${OUTFILE}" +for file in caKey caCert serverKey serverCert; do + DATA=$(cat "${file}.pem") + echo "" >> "${OUTFILE}" + echo "var $file = []byte(\`$DATA\`)" >> "${OUTFILE}" +done + +# Copy generated code into the output directory. +cp "${OUTFILE}" "${OUTDIR}/$1" diff --git a/webhook/pkg/injector/webhook.go b/webhook/pkg/injector/webhook.go new file mode 100644 index 000000000..614b5add7 --- /dev/null +++ b/webhook/pkg/injector/webhook.go @@ -0,0 +1,211 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package injector handles mutating webhook operations. +package injector + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "os" + + "github.com/mattbaird/jsonpatch" + "gvisor.dev/gvisor/pkg/log" + admv1beta1 "k8s.io/api/admission/v1beta1" + admregv1beta1 "k8s.io/api/admissionregistration/v1beta1" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kubeclientset "k8s.io/client-go/kubernetes" +) + +const ( + // Name is the name of the admission webhook service. The admission + // webhook must be exposed in the following service; this is mainly for + // the server certificate. + Name = "gvisor-injection-admission-webhook" + + // serviceNamespace is the namespace of the admission webhook service. + serviceNamespace = "e2e" + + fullName = Name + "." 
+ serviceNamespace + ".svc" +) + +// CreateConfiguration creates MutatingWebhookConfiguration and registers the +// webhook admission controller with the kube-apiserver. The webhook will only +// take effect on pods in the namespaces selected by `podNsSelector`. If `podNsSelector` +// is empty, the webhook will take effect on all pods. +func CreateConfiguration(clientset kubeclientset.Interface, selector *metav1.LabelSelector) error { + fail := admregv1beta1.Fail + + config := &admregv1beta1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: Name, + }, + Webhooks: []admregv1beta1.MutatingWebhook{ + { + Name: fullName, + ClientConfig: admregv1beta1.WebhookClientConfig{ + Service: &admregv1beta1.ServiceReference{ + Name: Name, + Namespace: serviceNamespace, + }, + CABundle: caCert, + }, + Rules: []admregv1beta1.RuleWithOperations{ + { + Operations: []admregv1beta1.OperationType{ + admregv1beta1.Create, + }, + Rule: admregv1beta1.Rule{ + APIGroups: []string{"*"}, + APIVersions: []string{"*"}, + Resources: []string{"pods"}, + }, + }, + }, + FailurePolicy: &fail, + NamespaceSelector: selector, + }, + }, + } + log.Infof("Creating MutatingWebhookConfiguration %q", config.Name) + if _, err := clientset.AdmissionregistrationV1beta1().MutatingWebhookConfigurations().Create(config); err != nil { + if !apierrors.IsAlreadyExists(err) { + return fmt.Errorf("failed to create MutatingWebhookConfiguration %q: %s", config.Name, err) + } + log.Infof("MutatingWebhookConfiguration %q already exists; use the existing one", config.Name) + } + return nil +} + +// GetTLSConfig retrieves the CA cert that signed the cert used by the webhook. +func GetTLSConfig() *tls.Config { + serverCert, err := tls.X509KeyPair(serverCert, serverKey) + if err != nil { + log.Warningf("Failed to generate X509 key pair: %v", err) + os.Exit(1) + } + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } +} + +// Admit performs admission checks and mutations on Pods. +func Admit(writer http.ResponseWriter, req *http.Request) { + review := &admv1beta1.AdmissionReview{} + if err := json.NewDecoder(req.Body).Decode(review); err != nil { + log.Infof("Failed with error (%v) to decode Admit request: %+v", err, *req) + writer.WriteHeader(http.StatusBadRequest) + return + } + + log.Debugf("admitPod: %+v", review) + var err error + review.Response, err = admitPod(review.Request) + if err != nil { + log.Warningf("admitPod failed: %v", err) + review.Response = &admv1beta1.AdmissionResponse{ + Result: &metav1.Status{ + Reason: metav1.StatusReasonInvalid, + Message: err.Error(), + }, + } + sendResponse(writer, review) + return + } + + log.Debugf("Processed admission review: %+v", review) + sendResponse(writer, review) +} + +func sendResponse(writer http.ResponseWriter, response interface{}) { + b, err := json.Marshal(response) + if err != nil { + log.Warningf("Failed with error (%v) to marshal response: %+v", err, response) + writer.WriteHeader(http.StatusInternalServerError) + return + } + + writer.WriteHeader(http.StatusOK) + writer.Write(b) +} + +func admitPod(req *admv1beta1.AdmissionRequest) (*admv1beta1.AdmissionResponse, error) { + // Verify that the request is indeed a Pod. + resource := metav1.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"} + if req.Resource != resource { + return nil, fmt.Errorf("unexpected resource %+v in pod admission", req.Resource) + } + + // Decode the request into a Pod. 
+ pod := &v1.Pod{} + if err := json.Unmarshal(req.Object.Raw, pod); err != nil { + return nil, fmt.Errorf("failed to decode pod object %s/%s", req.Namespace, req.Name) + } + + // Copy first to change it. + podCopy := pod.DeepCopy() + updatePod(podCopy) + patch, err := createPatch(req.Object.Raw, podCopy) + if err != nil { + return nil, fmt.Errorf("failed to create patch for pod %s/%s (generatedName: %s)", pod.Namespace, pod.Name, pod.GenerateName) + } + + log.Debugf("Patched pod %s/%s (generateName: %s): %+v", pod.Namespace, pod.Name, pod.GenerateName, podCopy) + patchType := admv1beta1.PatchTypeJSONPatch + return &admv1beta1.AdmissionResponse{ + Allowed: true, + Patch: patch, + PatchType: &patchType, + }, nil +} + +func updatePod(pod *v1.Pod) { + gvisor := "gvisor" + pod.Spec.RuntimeClassName = &gvisor + + // We don't run SELinux test for gvisor. + // If SELinuxOptions are specified, this is usually for volume test to pass + // on SELinux. This can be safely ignored. + if pod.Spec.SecurityContext != nil && pod.Spec.SecurityContext.SELinuxOptions != nil { + pod.Spec.SecurityContext.SELinuxOptions = nil + } + for i := range pod.Spec.Containers { + c := &pod.Spec.Containers[i] + if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil { + c.SecurityContext.SELinuxOptions = nil + } + } + for i := range pod.Spec.InitContainers { + c := &pod.Spec.InitContainers[i] + if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil { + c.SecurityContext.SELinuxOptions = nil + } + } +} + +func createPatch(old []byte, newObj interface{}) ([]byte, error) { + new, err := json.Marshal(newObj) + if err != nil { + return nil, err + } + patch, err := jsonpatch.CreatePatch(old, new) + if err != nil { + return nil, err + } + return json.Marshal(patch) +} diff --git a/website/BUILD b/website/BUILD index f3642b903..676c2b701 100644 --- a/website/BUILD +++ b/website/BUILD @@ -6,11 +6,16 @@ package(licenses = ["notice"]) docker_image( name = "website", - data = [":files"], + data = ":files", statements = [ "EXPOSE 8080/tcp", 'ENTRYPOINT ["/server"]', ], + tags = [ + "local", + "manual", + "nosandbox", + ], ) # files is the full file system of the generated container. diff --git a/website/_config.yml b/website/_config.yml index 20fbb3d2d..dc44945bc 100644 --- a/website/_config.yml +++ b/website/_config.yml @@ -37,3 +37,10 @@ authors: fvoznika: name: Fabricio Voznika email: fvoznika@google.com + ianlewis: + name: Ian Lewis + email: ianlewis@google.com + url: https://twitter.com/IanMLewis + mpratt: + name: Michael Pratt + email: mpratt@google.com diff --git a/website/_includes/byline.html b/website/_includes/byline.html index d8ae22cb0..1e808260f 100644 --- a/website/_includes/byline.html +++ b/website/_includes/byline.html @@ -5,7 +5,7 @@ By {% assign author_id=include.authors[i] %} {% assign author=site.authors[author_id] %} {% if author %} - <a href="mailto:{{ author.email }}">{{ author.name }}</a> + <a href="{% if author.url %}{{ author.url }}{% else %}mailto:{{ author.email }}{% endif %}">{{ author.name }}</a> {% else %} {{ author_id }} {% endif %} diff --git a/website/blog/2020-10-22-platform-portability.md b/website/blog/2020-10-22-platform-portability.md new file mode 100644 index 000000000..4d82940f9 --- /dev/null +++ b/website/blog/2020-10-22-platform-portability.md @@ -0,0 +1,120 @@ +# Platform Portability + +Hardware virtualization is often seen as a requirement to provide an additional +isolation layer for untrusted applications. 
However, hardware virtualization
+requires expensive bare-metal machines or cloud instances to run safely with
+good performance, increasing cost and complexity for Cloud users. gVisor,
+however, takes a more flexible approach.
+
+One of the pillars of gVisor's architecture is portability, allowing it to run
+anywhere that runs Linux. Modern Cloud-Native applications run in containers in
+many different places, from bare metal to virtual machines, and can't always
+rely on nested virtualization. It is important for gVisor to be able to support
+the environments where you run containers.
+
+gVisor achieves portability through an abstraction called a _Platform_.
+Platforms can have many implementations, and each implementation can cover
+different environments, making use of available software or hardware features.
+
+## Background
+
+Before we can understand how gVisor achieves portability using platforms, we
+should take a step back and understand how applications interact with their
+host.
+
+Container sandboxes can provide an isolation layer between the host and
+application by virtualizing one of the layers below it, including the hardware
+or operating system. Many sandboxes virtualize the hardware layer by running
+applications in virtual machines. gVisor takes a different approach by
+virtualizing the OS layer.
+
+When an application runs normally, the host operating system loads the
+application into user memory and schedules it for execution. The operating
+system scheduler eventually schedules the application to a CPU and begins
+executing it. The operating system then handles the application's requests,
+such as for memory, and manages the application's lifecycle. gVisor virtualizes
+these interactions, such as the system calls and context switches that happen
+between an application and the OS.
+
+[System calls](https://en.wikipedia.org/wiki/System_call) allow applications to
+ask the OS to perform some task for them. System calls look like normal
+function calls in most programming languages, though they work a bit
+differently under the hood. When an application makes a system call, some
+special processing takes place to perform a
+[context switch](https://en.wikipedia.org/wiki/Context_switch) into kernel mode
+and begin executing code in the kernel before returning a result to the
+application. Context switching may happen in other situations as well, for
+example, to respond to an interrupt.
+
+## The Platform Interface
+
+gVisor provides a sandbox that implements the Linux OS interface, intercepting
+OS interactions such as system calls and implementing them in the sandbox
+kernel.
+
+It does this to limit interactions with the host and protect the host from an
+untrusted application running in the sandbox. The Platform is the bottom layer
+of gVisor which provides the environment necessary for gVisor to control and
+manage applications. In general, the Platform must:
+
+1. Provide the ability to create and manage memory address spaces.
+2. Provide execution contexts for running applications in those memory address
+   spaces.
+3. Provide the ability to change execution context and return control to gVisor
+   at specific times (e.g. system call, page fault).
+
+This interface is conceptually simple, but very powerful. Since the Platform
+interface only requires these three capabilities, it gives gVisor enough control
+for it to act as the application's OS, while still allowing the use of very
+different isolation technologies under the hood.
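To make those three requirements concrete, a Platform can be pictured as a small Go interface along the following lines. This is only an illustrative sketch, not the actual interface in gVisor's platform package, which carries more methods and arguments.

```go
// Illustrative sketch only; not the actual gVisor platform package.
package platform

// AddressSpace represents an application memory address space (details
// elided in this sketch).
type AddressSpace interface{}

// ExitReason says why control returned to gVisor, e.g. a system call or a
// page fault occurred.
type ExitReason int

// Context is an execution context for running application code.
type Context interface {
    // Switch runs the application in the given address space until control
    // must return to gVisor (system call, page fault, etc.).
    Switch(as AddressSpace) (ExitReason, error)
}

// Platform ties the three requirements together.
type Platform interface {
    NewAddressSpace() (AddressSpace, error) // 1: manage address spaces.
    NewContext() Context                    // 2: create execution contexts.
    // 3: returning control at specific times is what Context.Switch provides.
}
```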
You can learn more about the +Platform interface in the +[Platform Guide](https://gvisor.dev/docs/architecture_guide/platforms/). + +## Implementations of the Platform Interface + +While gVisor can make use of technologies like hardware virtualization, it +doesn't necessarily rely on any one technology to provide a similar level of +isolation. The flexibility of the Platform interface allows for implementations +that use technologies other than hardware virtualization. This allows gVisor to +run in VMs without nested virtualization, for example. By providing an +abstraction for the underlying platform, each implementation can make various +tradeoffs regarding performance or hardware requirements. + +Currently gVisor provides two gVisor Platform implementations; the Ptrace +Platform, and the KVM Platform, each using very different methods to implement +the Platform interface. + +![gVisor Platforms](../../../../../docs/architecture_guide/platforms/platforms.png "Platforms") + +The Ptrace Platform uses +[PTRACE\_SYSEMU](http://man7.org/linux/man-pages/man2/ptrace.2.html) to trap +syscalls, and uses the host for memory mapping and context switching. This +platform can run anywhere that ptrace is available, which includes most Linux +systems, VMs or otherwise. + +The KVM Platform uses virtualization, but in an unconventional way. gVisor runs +in a virtual machine but as both guest OS and VMM, and presents no virtualized +hardware layer. This provides a simpler interface that can avoid hardware +initialization for fast start up, while taking advantage of hardware +virtualization support to improve memory isolation and performance of context +switching. + +The flexibility of the Platform interface allows for a lot of room to improve +the existing KVM and ptrace platforms, as well as the ability to utilize new +methods for improving gVisor's performance or portability in future Platform +implementations. + +## Portability + +Through the Platform interface, gVisor is able to support bare metal, virtual +machines, and Cloud environments while still providing a highly secure sandbox +for running untrusted applications. This is especially important for Cloud and +Kubernetes users because it allows gVisor to run anywhere that Kubernetes can +run and provide similar experiences in multi-region, hybrid, multi-platform +environments. + +Give gVisor's open source platforms a try. Using a Platform is as easy as +providing the `--platform` flag to `runsc`. See the documentation on +[changing platforms](https://gvisor.dev/docs/user_guide/platforms/) for how to +use different platforms with Docker. We would love to hear about your experience +so come chat with us in our +[Gitter channel](https://gitter.im/gvisor/community), or send us an +[issue on Github](https://gvisor.dev/issue) if you run into any problems. diff --git a/website/blog/BUILD b/website/blog/BUILD index 865e403da..17beb721f 100644 --- a/website/blog/BUILD +++ b/website/blog/BUILD @@ -38,6 +38,17 @@ doc( permalink = "/blog/2020/09/18/containing-a-real-vulnerability/", ) +doc( + name = "platform_portability", + src = "2020-10-22-platform-portability.md", + authors = [ + "ianlewis", + "mpratt", + ], + layout = "post", + permalink = "/blog/2020/10/22/platform-portability/", +) + docs( name = "posts", deps = [ diff --git a/website/cmd/server/main.go b/website/cmd/server/main.go index c401b6abd..ac09550a9 100644 --- a/website/cmd/server/main.go +++ b/website/cmd/server/main.go @@ -29,6 +29,7 @@ var redirects = map[string]string{ // GitHub redirects. 
"/change": "https://github.com/google/gvisor", "/issue": "https://github.com/google/gvisor/issues", + "/issues": "https://github.com/google/gvisor/issues", "/issue/new": "https://github.com/google/gvisor/issues/new", "/pr": "https://github.com/google/gvisor/pulls", @@ -44,14 +45,16 @@ var redirects = map[string]string{ "/c/linux/amd64": "/docs/user_guide/compatibility/linux/amd64/", // Redirect for old URLs. - "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/", - "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/", - "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/", - "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/", - "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/", - "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/", - "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/", - "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/", + "/docs/user_guide/compatibility/amd64/": "/docs/user_guide/compatibility/linux/amd64/", + "/docs/user_guide/compatibility/amd64": "/docs/user_guide/compatibility/linux/amd64/", + "/docs/user_guide/kubernetes/": "/docs/user_guide/quick_start/kubernetes/", + "/docs/user_guide/kubernetes": "/docs/user_guide/quick_start/kubernetes/", + "/docs/user_guide/oci/": "/docs/user_guide/quick_start/oci/", + "/docs/user_guide/oci": "/docs/user_guide/quick_start/oci/", + "/docs/user_guide/docker/": "/docs/user_guide/quick_start/docker/", + "/docs/user_guide/docker": "/docs/user_guide/quick_start/docker/", + "/blog/2020/09/22/platform-portability": "/blog/2020/10/22/platform-portability/", + "/blog/2020/09/22/platform-portability/": "/blog/2020/10/22/platform-portability/", // Deprecated, but links continue to work. "/cl": "https://gvisor-review.googlesource.com", @@ -60,6 +63,7 @@ var redirects = map[string]string{ var prefixHelpers = map[string]string{ "change": "https://github.com/google/gvisor/commit/%s", "issue": "https://github.com/google/gvisor/issues/%s", + "issues": "https://github.com/google/gvisor/issues/%s", "pr": "https://github.com/google/gvisor/pull/%s", // Redirects to compatibility docs. |