diff options
232 files changed, 7229 insertions, 3144 deletions
@@ -93,6 +93,7 @@ go_path( "//runsc/cli", "//shim/v1/cli", "//shim/v2/cli", + "//webhook/pkg/cli", # Packages that are not dependencies of the above. "//pkg/sentry/kernel/memevent", @@ -156,12 +156,24 @@ syscall-tests: ## Run all system call tests. @$(call submake,test TARGETS="test/syscalls/...") %-runtime-tests: load-runtimes_% +ifeq ($(PARTITION),) + @$(eval PARTITION := 1) +endif +ifeq ($(TOTAL_PARTITIONS),) + @$(eval TOTAL_PARTITIONS := 1) +endif @$(call submake,install-test-runtime) - @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + @$(call submake,test-runtime OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*") %-runtime-tests_vfs2: load-runtimes_% +ifeq ($(PARTITION),) + @$(eval PARTITION := 1) +endif +ifeq ($(TOTAL_PARTITIONS),) + @$(eval TOTAL_PARTITIONS := 1) +endif @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") - @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800 --test_arg=--partition=$(PARTITION) --test_arg=--total_partitions=$(TOTAL_PARTITIONS)" TARGETS="//test/runtimes:$*") do-tests: runsc @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true") @@ -210,6 +222,15 @@ iptables-tests: load-iptables @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") .PHONY: iptables-tests +# Run the iptables tests with runsc only. Useful for developing to skip runc +# testing. +iptables-runsc-tests: load-iptables + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter + @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw") + @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") +.PHONY: iptables-runsc-tests + packetdrill-tests: load-packetdrill @$(call submake,install-test-runtime RUNTIME="packetdrill") @$(call submake,test-runtime RUNTIME="packetdrill" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetdrill, tests(//...))')") @@ -240,6 +261,61 @@ containerd-tests: containerd-test-1.3.4 containerd-tests: containerd-test-1.4.0-beta.0 ## +## Benchmarks. +## +## Targets to run benchmarks. See //test/benchmarks for details. +## +## common arguments: +## RUNTIME_ARGS - arguments to runsc placed in /etc/docker/daemon.json +## e.g. "--platform=ptrace" +## BENCHMARKS_PROJECT - BigQuery project to which to send data. +## BENCHMARKS_DATASET - BigQuery dataset to which to send data. +## BENCHMARKS_TABLE - BigQuery table to which to send data. +## BENCHMARKS_SUITE - name of the benchmark suite. See //tools/bigquery/bigquery.go. +## BENCHMARKS_UPLOAD - if true, upload benchmark data from the run. +## BENCHMARKS_OFFICIAL - marks the data as official. +## BENCHMARKS_PLATFORMS - platforms to run benchmarks (e.g. ptrace kvm). +## +RUNTIME_ARGS := --net-raw --platform=ptrace +BENCHMARKS_PROJECT := gvisor-benchmarks +BENCHMARKS_DATASET := kokoro +BENCHMARKS_TABLE := benchmarks +BENCHMARKS_SUITE := start +BENCHMARKS_UPLOAD := false +BENCHMARKS_OFFICIAL := false +BENCHMARKS_PLATFORMS := ptrace + +init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema +## (see //tools/bigquery/bigquery.go). If the table alread exists, this is a noop. + $(call submake, run TARGETS=//tools/parsers:parser ARGS="init --project=$(BENCHMARKS_PROJECT) \ + --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)") +.PHONY: init-benchmark-table + +benchmark-platforms: load-benchmarks-images ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS. + $(call submake, run-benchmark RUNTIME="runc") + $(foreach PLATFORM,$(BENCHMARKS_PLATFORMS),\ + $(call submake,benchmark-platform RUNTIME="$(PLATFORM)" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw --vfs2") && \ + $(call submake,benchmark-platform RUNTIME="$(PLATFORM)_vfs1" RUNTIME_ARGS="--platform=$(PLATFORM) --net-raw")) +.PHONY: benchmark-platforms + +benchmark-platform: ## Installs a runtime with the given platform args. + @$(call submake,install-test-runtime ARGS="$(RUNTIME_ARGS)") + @$(call submake, run-benchmark) +.PHONY: benchmark-platform + +run-benchmark: ## Runs single benchmark and optionally sends data to BigQuery. + $(eval T := $(shell mktemp /tmp/logs.$(RUNTIME).XXXXXX)) + $(call submake,sudo TARGETS="$(TARGETS)" ARGS="--runtime=$(RUNTIME) $(ARGS)" | tee $(T)) + @if [[ "$(BENCHMARKS_UPLOAD)" == "true" ]]; then \ + @$(call submake,run TARGETS=tools/parsers:parser ARGS="parse --file=$(T) \ + --runtime=$(RUNTIME) --suite_name=$(BENCHMARKS_SUITE) \ + --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) \ + --table=$(BENCHMARKS_TABLE) --official=$(BENCHMARKS_OFFICIAL)"); \ + fi; + rm -rf $T +.PHONY: run-benchmark + +## ## Website & documentation helpers. ## ## The website is built from repository documentation and wrappers, using @@ -23,13 +23,13 @@ bazel_skylib_workspace() http_archive( name = "io_bazel_rules_go", - sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00", patch_args = ["-p1"], patches = [ # Newer versions of the rules_go rules will automatically strip test # binaries of symbols, which we don't want. "//tools:rules_go.patch", ], + sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz", "https://github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz", @@ -49,7 +49,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe go_rules_dependencies() -go_register_toolchains(go_version = "1.14.2") +go_register_toolchains(go_version = "1.15.2") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository") @@ -58,7 +58,7 @@ gazelle_dependencies() # The com_google_protobuf repository below would trigger downloading a older # version of org_golang_x_sys. If putting this repository statment in a place # after that of the com_google_protobuf, this statement will not work as -# expectd to download a new version of org_golang_x_sys. +# expected to download a new version of org_golang_x_sys. go_repository( name = "org_golang_x_sys", importpath = "golang.org/x/sys", @@ -222,8 +222,8 @@ go_repository( go_repository( name = "com_github_google_uuid", importpath = "github.com/google/uuid", - sum = "h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=", - version = "v1.0.0", + sum = "h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=", + version = "v1.1.1", ) go_repository( @@ -328,8 +328,8 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", - version = "v0.0.0-20201002184944-ecd9fd270d5d", + sum = "h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk=", + version = "v0.0.0-20201021000207-d49c4edd7d96", ) go_repository( @@ -349,8 +349,8 @@ go_repository( go_repository( name = "com_github_golang_protobuf", importpath = "github.com/golang/protobuf", - sum = "h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=", - version = "v1.4.2", + sum = "h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0=", + version = "v1.4.1", ) go_repository( @@ -412,7 +412,7 @@ go_repository( go_repository( name = "com_github_konsorten_go_windows_terminal_sequences", importpath = "github.com/konsorten/go-windows-terminal-sequences", - sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", + sum = "h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8=", version = "v1.0.3", ) @@ -461,8 +461,8 @@ go_repository( go_repository( name = "org_uber_go_multierr", importpath = "go.uber.org/multierr", - sum = "h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4=", - version = "v1.2.0", + sum = "h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4=", + version = "v1.6.0", ) go_repository( @@ -623,8 +623,8 @@ go_repository( go_repository( name = "com_github_google_go_cmp", importpath = "github.com/google/go-cmp", - sum = "h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k=", - version = "v0.5.1", + sum = "h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI=", + version = "v0.5.3-0.20201020212313-ab46b8bd0abd", ) go_repository( @@ -721,8 +721,8 @@ go_repository( go_repository( name = "com_github_spf13_pflag", importpath = "github.com/spf13/pflag", - sum = "h1:j8jxLbQ0+T1DFggy6XoGvyUnrJWPR/JybflPvu5rwS4=", - version = "v1.0.1-0.20171106142849-4c012f6dcd95", + sum = "h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=", + version = "v1.0.5", ) go_repository( @@ -763,15 +763,15 @@ go_repository( go_repository( name = "org_golang_google_genproto", importpath = "google.golang.org/genproto", - sum = "h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=", - version = "v0.0.0-20200526211855-cb27e3aa2013", + sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", + version = "v0.0.0-20200117163144-32f20d992d24", ) go_repository( name = "org_golang_google_protobuf", importpath = "google.golang.org/protobuf", - sum = "h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=", - version = "v1.25.1-0.20200808011614-a180de9f97d9", + sum = "h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A=", + version = "v1.25.1-0.20201020201750-d3470999428b", ) go_repository( @@ -1032,3 +1032,356 @@ go_repository( sum = "h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=", version = "v1.0.0", ) + +go_repository( + name = "com_github_azure_go_autorest_autorest", + importpath = "github.com/Azure/go-autorest/autorest", + sum = "h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs=", + version = "v0.9.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_adal", + importpath = "github.com/Azure/go-autorest/autorest/adal", + sum = "h1:q2gDruN08/guU9vAjuPWff0+QIrpH6ediguzdAzXAUU=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_date", + importpath = "github.com/Azure/go-autorest/autorest/date", + sum = "h1:YGrhWfrgtFs84+h0o46rJrlmsZtyZRg470CqAXTZaGM=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_azure_go_autorest_autorest_mocks", + importpath = "github.com/Azure/go-autorest/autorest/mocks", + sum = "h1:Ww5g4zThfD/6cLb4z6xxgeyDa7QDkizMkJKe0ysZXp0=", + version = "v0.2.0", +) + +go_repository( + name = "com_github_azure_go_autorest_logger", + importpath = "github.com/Azure/go-autorest/logger", + sum = "h1:ruG4BSDXONFRrZZJ2GUXDiUyVpayPmb1GnWeHDdaNKY=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_azure_go_autorest_tracing", + importpath = "github.com/Azure/go-autorest/tracing", + sum = "h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k=", + version = "v0.5.0", +) + +go_repository( + name = "com_github_dgrijalva_jwt_go", + importpath = "github.com/dgrijalva/jwt-go", + sum = "h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM=", + version = "v3.2.0+incompatible", +) + +go_repository( + name = "com_github_docker_spdystream", + importpath = "github.com/docker/spdystream", + sum = "h1:cenwrSVm+Z7QLSV/BsnenAOcDXdX4cMv4wP0B/5QbPg=", + version = "v0.0.0-20160310174837-449fdfce4d96", +) + +go_repository( + name = "com_github_elazarl_goproxy", + importpath = "github.com/elazarl/goproxy", + sum = "h1:p1yVGRW3nmb85p1Sh1ZJSDm4A4iKLS5QNbvUHMgGu/M=", + version = "v0.0.0-20170405201442-c4fc26588b6e", +) + +go_repository( + name = "com_github_emicklei_go_restful", + importpath = "github.com/emicklei/go-restful", + sum = "h1:H2pdYOb3KQ1/YsqVWoWNLQO+fusocsw354rqGTZtAgw=", + version = "v0.0.0-20170410110728-ff4f55a20633", +) + +go_repository( + name = "com_github_evanphx_json_patch", + importpath = "github.com/evanphx/json-patch", + sum = "h1:fUDGZCv/7iAN7u0puUVhvKCcsR6vRfwrJatElLBEf0I=", + version = "v4.2.0+incompatible", +) + +go_repository( + name = "com_github_fsnotify_fsnotify", + importpath = "github.com/fsnotify/fsnotify", + sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=", + version = "v1.4.7", +) + +go_repository( + name = "com_github_ghodss_yaml", + importpath = "github.com/ghodss/yaml", + sum = "h1:ZktWZesgun21uEDrwW7iEV1zPCGQldM2atlJZ3TdvVM=", + version = "v0.0.0-20150909031657-73d445a93680", +) + +go_repository( + name = "com_github_go_logr_logr", + importpath = "github.com/go-logr/logr", + sum = "h1:M1Tv3VzNlEHg6uyACnRdtrploV2P7wZqH8BoQMtz0cg=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_go_openapi_jsonpointer", + importpath = "github.com/go-openapi/jsonpointer", + sum = "h1:wSt/4CYxs70xbATrGXhokKF1i0tZjENLOo1ioIO13zk=", + version = "v0.0.0-20160704185906-46af16f9f7b1", +) + +go_repository( + name = "com_github_go_openapi_jsonreference", + importpath = "github.com/go-openapi/jsonreference", + sum = "h1:tF+augKRWlWx0J0B7ZyyKSiTyV6E1zZe+7b3qQlcEf8=", + version = "v0.0.0-20160704190145-13c6e3589ad9", +) + +go_repository( + name = "com_github_go_openapi_spec", + importpath = "github.com/go-openapi/spec", + sum = "h1:C1JKChikHGpXwT5UQDFaryIpDtyyGL/CR6C2kB7F1oc=", + version = "v0.0.0-20160808142527-6aced65f8501", +) + +go_repository( + name = "com_github_go_openapi_swag", + importpath = "github.com/go-openapi/swag", + sum = "h1:zP3nY8Tk2E6RTkqGYrarZXuzh+ffyLDljLxCy1iJw80=", + version = "v0.0.0-20160704191624-1d0bd113de87", +) + +go_repository( + name = "com_github_google_gofuzz", + importpath = "github.com/google/gofuzz", + sum = "h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_googleapis_gnostic", + build_file_proto_mode = "disable_global", + importpath = "github.com/googleapis/gnostic", + sum = "h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k=", + version = "v0.0.0-20170729233727-0c5108395e2d", +) + +go_repository( + name = "com_github_gophercloud_gophercloud", + importpath = "github.com/gophercloud/gophercloud", + sum = "h1:P/nh25+rzXouhytV2pUHBb65fnds26Ghl8/391+sT5o=", + version = "v0.1.0", +) + +go_repository( + name = "com_github_gregjones_httpcache", + importpath = "github.com/gregjones/httpcache", + sum = "h1:pdN6V1QBWetyv/0+wjACpqVH+eVULgEjkurDLq3goeM=", + version = "v0.0.0-20180305231024-9cad4c3443a7", +) + +go_repository( + name = "com_github_hpcloud_tail", + importpath = "github.com/hpcloud/tail", + sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_imdario_mergo", + importpath = "github.com/imdario/mergo", + sum = "h1:JboBksRwiiAJWvIYJVo46AfV+IAIKZpfrSzVKj42R4Q=", + version = "v0.3.5", +) + +go_repository( + name = "com_github_json_iterator_go", + importpath = "github.com/json-iterator/go", + sum = "h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo=", + version = "v1.1.7", +) + +go_repository( + name = "com_github_mailru_easyjson", + importpath = "github.com/mailru/easyjson", + sum = "h1:TpvdAwDAt1K4ANVOfcihouRdvP+MgAfDWwBuct4l6ZY=", + version = "v0.0.0-20160728113105-d5b7844b561a", +) + +go_repository( + name = "com_github_mattbaird_jsonpatch", + importpath = "github.com/mattbaird/jsonpatch", + sum = "h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8=", + version = "v0.0.0-20171005235357-81af80346b1a", +) + +go_repository( + name = "com_github_modern_go_concurrent", + importpath = "github.com/modern-go/concurrent", + sum = "h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=", + version = "v0.0.0-20180306012644-bacd9c7ef1dd", +) + +go_repository( + name = "com_github_modern_go_reflect2", + importpath = "github.com/modern-go/reflect2", + sum = "h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=", + version = "v1.0.1", +) + +go_repository( + name = "com_github_munnerz_goautoneg", + importpath = "github.com/munnerz/goautoneg", + sum = "h1:7PxY7LVfSZm7PEeBTyK1rj1gABdCO2mbri6GKO1cMDs=", + version = "v0.0.0-20120707110453-a547fc61f48d", +) + +go_repository( + name = "com_github_mxk_go_flowrate", + importpath = "github.com/mxk/go-flowrate", + sum = "h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus=", + version = "v0.0.0-20140419014527-cca7078d478f", +) + +go_repository( + name = "com_github_nytimes_gziphandler", + importpath = "github.com/NYTimes/gziphandler", + sum = "h1:lsxEuwrXEAokXB9qhlbKWPpo3KMLZQ5WB5WLQRW1uq0=", + version = "v0.0.0-20170623195520-56545f4a5d46", +) + +go_repository( + name = "com_github_onsi_ginkgo", + importpath = "github.com/onsi/ginkgo", + sum = "h1:VkHVNpR4iVnU8XQR6DBm8BqYjN7CRzw+xKUbVVbbW9w=", + version = "v1.8.0", +) + +go_repository( + name = "com_github_onsi_gomega", + importpath = "github.com/onsi/gomega", + sum = "h1:izbySO9zDPmjJ8rDjLvkA2zJHIo+HkYXHnf7eN7SSyo=", + version = "v1.5.0", +) + +go_repository( + name = "com_github_peterbourgon_diskv", + importpath = "github.com/peterbourgon/diskv", + sum = "h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI=", + version = "v2.0.1+incompatible", +) + +go_repository( + name = "com_github_puerkitobio_purell", + importpath = "github.com/PuerkitoBio/purell", + sum = "h1:0GoNN3taZV6QI81IXgCbxMyEaJDXMSIjArYBCYzVVvs=", + version = "v1.0.0", +) + +go_repository( + name = "com_github_puerkitobio_urlesc", + importpath = "github.com/PuerkitoBio/urlesc", + sum = "h1:JCHLVE3B+kJde7bIEo5N4J+ZbLhp0J1Fs+ulyRws4gE=", + version = "v0.0.0-20160726150825-5bd2802263f2", +) + +go_repository( + name = "com_github_spf13_afero", + importpath = "github.com/spf13/afero", + sum = "h1:5jhuqJyZCZf2JRofRvN/nIFgIWNzPa3/Vz8mYylgbWc=", + version = "v1.2.2", +) + +go_repository( + name = "in_gopkg_fsnotify_v1", + importpath = "gopkg.in/fsnotify.v1", + sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=", + version = "v1.4.7", +) + +go_repository( + name = "in_gopkg_inf_v0", + importpath = "gopkg.in/inf.v0", + sum = "h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=", + version = "v0.9.1", +) + +go_repository( + name = "in_gopkg_tomb_v1", + importpath = "gopkg.in/tomb.v1", + sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=", + version = "v1.0.0-20141024135613-dd632973f1e7", +) + +go_repository( + name = "io_k8s_api", + build_file_proto_mode = "disable_global", + importpath = "k8s.io/api", + sum = "h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs=", + version = "v0.16.13", +) + +go_repository( + name = "io_k8s_apimachinery", + build_file_proto_mode = "disable_global", + importpath = "k8s.io/apimachinery", + sum = "h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc=", + version = "v0.16.14-rc.0", +) + +go_repository( + name = "io_k8s_client_go", + importpath = "k8s.io/client-go", + sum = "h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag=", + version = "v0.16.13", +) + +go_repository( + name = "io_k8s_gengo", + importpath = "k8s.io/gengo", + sum = "h1:4s3/R4+OYYYUKptXPhZKjQ04WJ6EhQQVFdjOFvCazDk=", + version = "v0.0.0-20190128074634-0689ccc1d7d6", +) + +go_repository( + name = "io_k8s_klog", + importpath = "k8s.io/klog", + sum = "h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8=", + version = "v1.0.0", +) + +go_repository( + name = "io_k8s_kube_openapi", + importpath = "k8s.io/kube-openapi", + sum = "h1:PsbYeEz2x7ll6JYUzBEG+DT78910DDTlvn5Ma10F5/E=", + version = "v0.0.0-20200410163147-594e756bea31", +) + +go_repository( + name = "io_k8s_sigs_structured_merge_diff", + importpath = "sigs.k8s.io/structured-merge-diff", + sum = "h1:4Z09Hglb792X0kfOBBJUPFEyvVfQWrYT/l8h5EKA6JQ=", + version = "v0.0.0-20190525122527-15d366b2352e", +) + +go_repository( + name = "io_k8s_sigs_yaml", + importpath = "sigs.k8s.io/yaml", + sum = "h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs=", + version = "v1.1.0", +) + +go_repository( + name = "io_k8s_utils", + importpath = "k8s.io/utils", + sum = "h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE=", + version = "v0.0.0-20190801114015-581e00157fb1", +) @@ -29,11 +29,12 @@ require ( github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect - github.com/google/go-cmp v0.5.1 // indirect + github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd // indirect github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect github.com/hashicorp/go-multierror v1.0.0 // indirect github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect + github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/opencontainers/runc v0.1.1 // indirect @@ -43,12 +44,13 @@ require ( github.com/urfave/cli v1.22.2 // indirect github.com/vishvananda/netlink v1.0.1-0.20190930145447-2ec5bdc52b86 // indirect github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae // indirect - go.uber.org/atomic v1.7.0 // indirect - go.uber.org/multierr v1.2.0 // indirect + go.uber.org/multierr v1.6.0 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect - golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d // indirect + golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 // indirect google.golang.org/grpc v1.29.0 // indirect - google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 // indirect - gopkg.in/yaml.v2 v2.2.8 // indirect + google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b // indirect gotest.tools v2.2.0+incompatible // indirect + k8s.io/api v0.16.13 + k8s.io/apimachinery v0.16.14-rc.0 + k8s.io/client-go v0.16.13 ) @@ -13,6 +13,13 @@ cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7 cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= @@ -27,6 +34,9 @@ github.com/Microsoft/hcsshim v0.8.8/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg3 github.com/Microsoft/hcsshim v0.8.9/go.mod h1:5692vkUqntj1idxauYlpoINNKeqCiG6Sg38RRsjT5y8= github.com/Microsoft/hcsshim v0.8.10 h1:k5wTrpnVU2/xv8ZuzGkbXVd3js5zJ8RnumPo5RxiIxU= github.com/Microsoft/hcsshim v0.8.10/go.mod h1:g5uw8EV2mAlzqe94tfNBNdr89fnbD/n3HV0OhsddkmM= +github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= +github.com/PuerkitoBio/purell v1.0.0/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422 h1:8eZxmY1yvxGHzdzTEhI09npjMVGzNAdrqzruTX6jcK4= github.com/cenkalti/backoff v1.1.1-0.20190506075156-2146c9339422/go.mod h1:b6Nc7NRH5C4aCISLry0tLnTjcuTEvoiqcWDdsU0sOGM= @@ -68,8 +78,11 @@ github.com/coreos/go-systemd/v22 v22.0.0 h1:XJIw/+VlJ+87J+doOxznsAWIdmWuViOVhkQa github.com/coreos/go-systemd/v22 v22.0.0/go.mod h1:xO0FLkIi5MaZafQlIrOotqXZ90ih+1atmu1JpKERPPk= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible h1:dvc1KSkIYTVjZgHf/CTC2diTYC8PzhaA5sFISRfNVrE= github.com/docker/distribution v2.7.1-0.20190205005809-0d3efadf0154+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 h1:5AkIsnQpeL7eaqsM+Vl4Xbj5eIZFpPZZzXtNyfzzK/w= @@ -80,14 +93,25 @@ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c h1:+pKlWGMw7gf6bQ github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= +github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v4.2.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= +github.com/go-openapi/jsonpointer v0.0.0-20160704185906-46af16f9f7b1/go.mod h1:+35s3my2LFTysnkMfxsJBAMHj/DoqoB9knIWoYG/Vk0= +github.com/go-openapi/jsonreference v0.0.0-20160704190145-13c6e3589ad9/go.mod h1:W3Z9FmVs9qj+KR4zFKmDPGiLdk1D9Rlm7cyMvf57TTg= +github.com/go-openapi/spec v0.0.0-20160808142527-6aced65f8501/go.mod h1:J8+jY1nAiCcj+friV/PDoE1/3eeccG9LYBs0tYvLOWc= +github.com/go-openapi/swag v0.0.0-20160704191624-1d0bd113de87/go.mod h1:DXUve3Dpr1UfpPtxFw+EFuQ41HhCWZfha5jSVRG7C7I= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= @@ -96,9 +120,11 @@ github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1 github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= +github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -106,6 +132,7 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v0.0.0-20161109072736-4bd1920723d7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -125,11 +152,14 @@ github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= -github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd h1:pJfrTSHC+QpCQplFZqzlwihfc+0Oty0ViHPHPxXj0SI= +github.com/google/go-cmp v0.5.3-0.20201020212313-ab46b8bd0abd/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8/go.mod h1:g82e6OHbJ0WYrYeOrid1MMfHAtqjxBz+N74tfAt9KrQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI= +github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -137,18 +167,28 @@ github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hf github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 h1:8nlgEAjIalk6uj/CGKCdOO8CQqTeysvcW4RFZ6HbkGM= github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d h1:7XGaL1e6bYS1yIonGp9761ExpPPV1ui0SAC59Yube9k= +github.com/googleapis/gnostic v0.0.0-20170729233727-0c5108395e2d/go.mod h1:sJBsCZ4ayReDTBIg8b9dl28c5xFWyhBTVRp3pOg5EKY= +github.com/gophercloud/gophercloud v0.1.0/go.mod h1:vxM41WHh5uqHVBMZHzuwNOHh8XEoIEcSTewFxm1c5g8= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uPribsnS6o= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/imdario/mergo v0.3.5/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.7 h1:KfgG9LzI+pYjr4xvmz/5H4FXjokeP+rlHLhv3iH62Fo= +github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1 h1:6QPYqodiu3GuPL+7mfx+NwDdp2eTkp9IfEUpgAwUN0o= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= @@ -164,8 +204,25 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a h1:+J2gw7Bw77w/fbK7wnNJJDKmw1IbWft2Ul5BzrG1Qm8= +github.com/mattbaird/jsonpatch v0.0.0-20171005235357-81af80346b1a/go.mod h1:M1qoD/MqPgTZIk0EWKB38wE28ACRfVcn+cU08jyArI0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180320133207-05fbef0ca5da/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= +github.com/onsi/gomega v1.5.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= @@ -181,9 +238,12 @@ github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNia github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/procfs v0.0.0-20180125133057-cb4147076ac7/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= @@ -196,12 +256,18 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= +github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 h1:b6uOv7YOFK0TYG7HtkIgExQo+2RdLuwRft63jn2HWj8= github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= @@ -217,12 +283,15 @@ go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= -go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200220183623-bac4c82f6975/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -247,6 +316,7 @@ golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20170114055629-f2499483f923/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -275,8 +345,11 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190209173611-3b5209105503/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -285,6 +358,7 @@ golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191022100944-742c48ecaeb7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -295,6 +369,7 @@ golang.org/x/sys v0.0.0-20200120151820-655fe14d7479/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -304,6 +379,7 @@ golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181011042414-1f849cf54d09/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -322,8 +398,8 @@ golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE= -golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96 h1:K+nJoPcImWk+ZGPHOKkDocKcQPACCz8usiCiVQYfXsk= +golang.org/x/tools v0.0.0-20201021000207-d49c4edd7d96/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -349,9 +425,8 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -359,7 +434,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -368,13 +442,18 @@ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w= -google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b h1:jEdfCm+8YTWSYgU4L7Nq0jjU+q9RxIhi0cXLTY+Ih3A= +google.golang.org/protobuf v1.25.1-0.20201020201750-d3470999428b/go.mod h1:hFxJC2f0epmp1elRCiEGJTKAWbwxZ2nvqZdHl3FQXCY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -383,4 +462,22 @@ honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +k8s.io/api v0.16.13 h1:/RE6SNxrws72vzEJsCil3WSR2T9gUlYYoRxnJyZiexs= +k8s.io/api v0.16.13/go.mod h1:QWu8UWSTiuQZMMeYjwLs6ILu5O74qKSJ0c+4vrchDxs= +k8s.io/apimachinery v0.16.13/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= +k8s.io/apimachinery v0.16.14-rc.0 h1:eUHWTe8VT+VOZVKGfSCcFZDrr9RZ8djLYGjIanaZnXc= +k8s.io/apimachinery v0.16.14-rc.0/go.mod h1:4HMHS3mDHtVttspuuhrJ1GGr/0S9B6iWYWZ57KnnZqQ= +k8s.io/client-go v0.16.13 h1:jp76b20+4h8qZBxferSAVZ6MjBEpw3F309zLmPhngag= +k8s.io/client-go v0.16.13/go.mod h1:UKvVT4cajC2iN7DCjLgT0KVY/cbY6DGdUCyRiIfws5M= +k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0= +k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= +k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= +k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8= +k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= +k8s.io/kube-openapi v0.0.0-20200410163147-594e756bea31/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E= +k8s.io/utils v0.0.0-20190801114015-581e00157fb1 h1:+ySTxfHnfzZb9ys375PXNlLhkJPLKgHajBU0N62BDvE= +k8s.io/utils v0.0.0-20190801114015-581e00157fb1/go.mod h1:sZAwmy6armz5eXlNoLmJcl4F1QuKu7sr+mFQ0byX7Ew= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +sigs.k8s.io/structured-merge-diff v0.0.0-20190525122527-15d366b2352e/go.mod h1:wWxsB5ozmmv/SG7nM11ayaAW51xMvak/t1r0CSlcokI= +sigs.k8s.io/yaml v1.1.0 h1:4A07+ZFc2wgJwo8YNlQpr1rVlgUDlxXHhPJciaPY5gs= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= diff --git a/images/README.md b/images/README.md index 9880946a6..297c7c3f3 100644 --- a/images/README.md +++ b/images/README.md @@ -41,9 +41,9 @@ All images will be tagged and memoized using a hash of the directory contents. As a result, every image should be made completely reproducible if possible. This means using fixed tags and fixed versions whenever feasible. -Notes that images should also be made architecture-independent if possible. The -build scripts will handling loading the appropriate architecture onto the -machine and tagging it with the single canonical tag. +Note that images should also be made architecture-independent if possible. The +build scripts will handle loading the appropriate architecture onto the machine +and tagging it with the single canonical tag. Add a `load-<image>` dependency in the Makefile if the image is required for a particular set of tests. This target will pull the tag from the image repository @@ -1,6 +1,6 @@ groups: # We define three basic groups: generated (all generated files), - # exteranl (all files outside the repository), and internal (all + # external (all files outside the repository), and internal (all # files within the local repository). We can't enforce many style # checks on generated and external code, so enable those cases # selectively for analyzers below. @@ -42,6 +42,10 @@ global: # Generated gRPC code is not compliant either. - "error strings should not be capitalized" - "grpc.Errorf is deprecated" + # Generated proto code does not always follow capitalization conventions. + - "(field|method|struct|type) .* should be .*" + # Generated proto code sometimes duplicates imports with aliases. + - "duplicate import" internal: suppress: # We use ALL_CAPS for system definitions, @@ -60,21 +64,10 @@ global: - pkg/abi/linux/fuse.go:25 - pkg/abi/linux/socket.go:113 - pkg/abi/linux/tty.go:73 - - pkg/bpf/decoder.go:112 - pkg/cpuid/cpuid_x86.go:675 - - pkg/eventchannel/event.go:193 - - pkg/eventchannel/event.go:27 - - pkg/eventchannel/event_test.go:22 - - pkg/eventchannel/rate.go:19 - pkg/gohacks/gohacks_unsafe.go:33 - pkg/log/json.go:30 - pkg/log/log.go:359 - - pkg/merkletree/merkletree.go:230 - - pkg/merkletree/merkletree.go:243 - - pkg/merkletree/merkletree.go:249 - - pkg/merkletree/merkletree.go:266 - - pkg/merkletree/merkletree.go:355 - - pkg/merkletree/merkletree.go:369 - pkg/metric/metric_test.go:20 - pkg/p9/p9test/client_test.go:687 - pkg/p9/transport_test.go:196 @@ -117,14 +110,6 @@ global: - pkg/sentry/fs/gofer/attr.go:120 - pkg/sentry/fs/gofer/fifo.go:33 - pkg/sentry/fs/gofer/inode.go:410 - - pkg/sentry/fsimpl/devpts/devpts.go:110 - - pkg/sentry/fsimpl/devpts/devpts.go:246 - - pkg/sentry/fsimpl/devpts/devpts.go:50 - - pkg/sentry/fsimpl/devpts/master.go:110 - - pkg/sentry/fsimpl/devpts/master.go:55 - - pkg/sentry/fsimpl/devpts/replica.go:113 - - pkg/sentry/fsimpl/devpts/replica.go:57 - - pkg/sentry/fsimpl/devtmpfs/devtmpfs.go:54 - pkg/sentry/fsimpl/ext/disklayout/superblock_64.go:97 - pkg/sentry/fsimpl/ext/disklayout/superblock_old.go:92 - pkg/sentry/fsimpl/ext/disklayout/block_group_32.go:44 @@ -132,219 +117,13 @@ global: - pkg/sentry/fsimpl/ext/disklayout/inode_old.go:93 - pkg/sentry/fsimpl/ext/disklayout/superblock_32.go:66 - pkg/sentry/fsimpl/ext/disklayout/block_group_64.go:53 - - pkg/sentry/fsimpl/eventfd/eventfd.go:268 - - pkg/sentry/fsimpl/ext/directory.go:163 - - pkg/sentry/fsimpl/ext/directory.go:164 - - pkg/sentry/fsimpl/ext/extent_file.go:142 - - pkg/sentry/fsimpl/ext/extent_file.go:143 - - pkg/sentry/fsimpl/ext/ext.go:105 - - pkg/sentry/fsimpl/ext/filesystem.go:287 - - pkg/sentry/fsimpl/ext/regular_file.go:153 - - pkg/sentry/fsimpl/ext/symlink.go:113 - - pkg/sentry/fsimpl/fuse/connection_control.go:194 - - pkg/sentry/fsimpl/fuse/dev.go:387 - - pkg/sentry/fsimpl/fuse/dev_test.go:318 - - pkg/sentry/fsimpl/fuse/fusefs.go:102 - - pkg/sentry/fsimpl/fuse/read_write.go:129 - pkg/sentry/fsimpl/fuse/request_response.go:71 - - pkg/sentry/fsimpl/gofer/directory.go:135 - - pkg/sentry/fsimpl/gofer/filesystem.go:679 - - pkg/sentry/fsimpl/gofer/gofer.go:1694 - - pkg/sentry/fsimpl/gofer/gofer.go:276 - - pkg/sentry/fsimpl/gofer/regular_file.go:81 - - pkg/sentry/fsimpl/gofer/special_file.go:141 - - pkg/sentry/fsimpl/host/host.go:184 - - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:50 - - pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go:90 - - pkg/sentry/fsimpl/kernfs/fd_impl_util.go:273 - - pkg/sentry/fsimpl/kernfs/filesystem.go:247 - - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:320 - - pkg/sentry/fsimpl/kernfs/inode_impl_util.go:497 - - pkg/sentry/fsimpl/kernfs/synthetic_directory.go:52 - - pkg/sentry/fsimpl/overlay/directory.go:119 - - pkg/sentry/fsimpl/overlay/filesystem.go:527 - - pkg/sentry/fsimpl/overlay/non_directory.go:152 - - pkg/sentry/fsimpl/overlay/overlay.go:115 - - pkg/sentry/fsimpl/overlay/overlay.go:719 - - pkg/sentry/fsimpl/pipefs/pipefs.go:74 - - pkg/sentry/fsimpl/proc/filesystem.go:52 - - pkg/sentry/fsimpl/proc/filesystem.go:81 - - pkg/sentry/fsimpl/proc/subtasks.go:126 - - pkg/sentry/fsimpl/proc/subtasks.go:189 - - pkg/sentry/fsimpl/proc/task_fds.go:168 - - pkg/sentry/fsimpl/proc/task_fds.go:228 - - pkg/sentry/fsimpl/proc/task_fds.go:301 - - pkg/sentry/fsimpl/proc/task_fds.go:318 - - pkg/sentry/fsimpl/proc/task_fds.go:67 - - pkg/sentry/fsimpl/proc/task_files.go:112 - - pkg/sentry/fsimpl/proc/task_files.go:158 - - pkg/sentry/fsimpl/proc/task_files.go:259 - - pkg/sentry/fsimpl/proc/task_files.go:285 - - pkg/sentry/fsimpl/proc/task_files.go:305 - - pkg/sentry/fsimpl/proc/task_files.go:384 - - pkg/sentry/fsimpl/proc/task_files.go:403 - - pkg/sentry/fsimpl/proc/task_files.go:428 - - pkg/sentry/fsimpl/proc/task_files.go:691 - - pkg/sentry/fsimpl/proc/task_files.go:770 - - pkg/sentry/fsimpl/proc/task_files.go:797 - - pkg/sentry/fsimpl/proc/task_files.go:828 - - pkg/sentry/fsimpl/proc/task_files.go:879 - - pkg/sentry/fsimpl/proc/task_files.go:910 - - pkg/sentry/fsimpl/proc/task_files.go:961 - - pkg/sentry/fsimpl/proc/task.go:127 - - pkg/sentry/fsimpl/proc/task.go:193 - - pkg/sentry/fsimpl/proc/task_net.go:134 - - pkg/sentry/fsimpl/proc/task_net.go:475 - - pkg/sentry/fsimpl/proc/task_net.go:491 - - pkg/sentry/fsimpl/proc/task_net.go:508 - - pkg/sentry/fsimpl/proc/task_net.go:665 - - pkg/sentry/fsimpl/proc/task_net.go:715 - - pkg/sentry/fsimpl/proc/task_net.go:779 - - pkg/sentry/fsimpl/proc/tasks_files.go:113 - - pkg/sentry/fsimpl/proc/tasks_files.go:388 - - pkg/sentry/fsimpl/proc/tasks.go:232 - - pkg/sentry/fsimpl/proc/tasks_sys.go:145 - - pkg/sentry/fsimpl/proc/tasks_sys.go:181 - - pkg/sentry/fsimpl/proc/tasks_sys.go:239 - - pkg/sentry/fsimpl/proc/tasks_sys.go:291 - - pkg/sentry/fsimpl/proc/tasks_sys.go:375 - - pkg/sentry/fsimpl/signalfd/signalfd.go:124 - pkg/sentry/fsimpl/signalfd/signalfd.go:15 - - pkg/sentry/fsimpl/signalfd/signalfd.go:126 - - pkg/sentry/fsimpl/sockfs/sockfs.go:36 - - pkg/sentry/fsimpl/sockfs/sockfs.go:79 - - pkg/sentry/fsimpl/sys/kcov.go:49 - - pkg/sentry/fsimpl/sys/kcov.go:99 - - pkg/sentry/fsimpl/sys/sys.go:118 - - pkg/sentry/fsimpl/sys/sys.go:56 - - pkg/sentry/fsimpl/testutil/testutil.go:257 - - pkg/sentry/fsimpl/testutil/testutil.go:260 - - pkg/sentry/fsimpl/timerfd/timerfd.go:87 - - pkg/sentry/fsimpl/tmpfs/directory.go:112 - - pkg/sentry/fsimpl/tmpfs/filesystem.go:195 - - pkg/sentry/fsimpl/tmpfs/regular_file.go:226 - - pkg/sentry/fsimpl/tmpfs/regular_file.go:346 - - pkg/sentry/fsimpl/tmpfs/tmpfs.go:103 - - pkg/sentry/fsimpl/tmpfs/tmpfs.go:733 - - pkg/sentry/fsimpl/verity/filesystem.go:490 - - pkg/sentry/fsimpl/verity/verity.go:156 - - pkg/sentry/fsimpl/verity/verity.go:629 - - pkg/sentry/fsimpl/verity/verity.go:672 - - pkg/sentry/fs/mount.go:162 - - pkg/sentry/fs/mount.go:256 - - pkg/sentry/fs/mount_overlay.go:144 - - pkg/sentry/fs/mounts.go:432 - - pkg/sentry/fs/proc/exec_args.go:104 - - pkg/sentry/fs/proc/exec_args.go:73 - - pkg/sentry/fs/proc/fds.go:269 - - pkg/sentry/fs/proc/loadavg.go:33 - - pkg/sentry/fs/proc/meminfo.go:39 - - pkg/sentry/fs/proc/mounts.go:193 - - pkg/sentry/fs/proc/mounts.go:84 - - pkg/sentry/fs/proc/net.go:125 - - pkg/sentry/fs/proc/proc.go:146 - - pkg/sentry/fs/proc/proc.go:204 - - pkg/sentry/fs/proc/seqfile/seqfile.go:210 - - pkg/sentry/fs/proc/sys.go:146 - - pkg/sentry/fs/proc/sys.go:43 - - pkg/sentry/fs/proc/sys_net.go:113 - - pkg/sentry/fs/proc/sys_net.go:205 - - pkg/sentry/fs/proc/sys_net.go:233 - - pkg/sentry/fs/proc/sys_net.go:307 - - pkg/sentry/fs/proc/sys_net.go:335 - - pkg/sentry/fs/proc/sys_net.go:446 - - pkg/sentry/fs/proc/sys_net.go:456 - - pkg/sentry/fs/proc/sys_net.go:89 - - pkg/sentry/fs/proc/task.go:170 - - pkg/sentry/fs/proc/task.go:322 - - pkg/sentry/fs/proc/task.go:427 - - pkg/sentry/fs/proc/task.go:467 - - pkg/sentry/fs/proc/task.go:500 - - pkg/sentry/fs/proc/task.go:784 - - pkg/sentry/fs/proc/task.go:839 - - pkg/sentry/fs/proc/task.go:920 - - pkg/sentry/fs/proc/uid_gid_map.go:108 - - pkg/sentry/fs/proc/uid_gid_map.go:79 - - pkg/sentry/fs/proc/uptime.go:75 - - pkg/sentry/fs/ramfs/dir.go:447 - - pkg/sentry/fs/tmpfs/inode_file.go:436 - - pkg/sentry/fs/tmpfs/inode_file.go:537 - - pkg/sentry/fs/tty/dir.go:313 - - pkg/sentry/fs/tty/master.go:131 - - pkg/sentry/fs/tty/master.go:91 - - pkg/sentry/fs/tty/replica.go:116 - - pkg/sentry/fs/tty/replica.go:88 - - pkg/sentry/kernel/auth/id_map.go:269 - - pkg/sentry/kernel/fasync/fasync.go:67 - - pkg/sentry/kernel/kcov.go:209 - - pkg/sentry/kernel/kcov.go:223 - - pkg/sentry/kernel/kernel.go:343 - - pkg/sentry/kernel/kernel.go:368 - - pkg/sentry/kernel/pipe/node_test.go:112 - - pkg/sentry/kernel/pipe/node_test.go:119 - - pkg/sentry/kernel/pipe/node_test.go:130 - - pkg/sentry/kernel/pipe/node_test.go:137 - - pkg/sentry/kernel/pipe/node_test.go:149 - - pkg/sentry/kernel/pipe/node_test.go:150 - - pkg/sentry/kernel/pipe/node_test.go:158 - - pkg/sentry/kernel/pipe/node_test.go:174 - - pkg/sentry/kernel/pipe/node_test.go:180 - - pkg/sentry/kernel/pipe/node_test.go:193 - - pkg/sentry/kernel/pipe/node_test.go:202 - - pkg/sentry/kernel/pipe/node_test.go:205 - - pkg/sentry/kernel/pipe/node_test.go:216 - - pkg/sentry/kernel/pipe/node_test.go:219 - - pkg/sentry/kernel/pipe/node_test.go:271 - - pkg/sentry/kernel/pipe/node_test.go:290 - - pkg/sentry/kernel/pipe/pipe_test.go:93 - - pkg/sentry/kernel/pipe/reader_writer.go:65 - - pkg/sentry/kernel/posixtimer.go:157 - - pkg/sentry/kernel/ptrace.go:218 - - pkg/sentry/kernel/semaphore/semaphore.go:323 - - pkg/sentry/kernel/sessions.go:123 - - pkg/sentry/kernel/sessions.go:508 - - pkg/sentry/kernel/signal_handlers.go:57 - - pkg/sentry/kernel/task_context.go:72 - - pkg/sentry/kernel/task_exit.go:67 - - pkg/sentry/kernel/task_sched.go:255 - - pkg/sentry/kernel/task_sched.go:280 - - pkg/sentry/kernel/task_sched.go:323 - - pkg/sentry/kernel/task_stop.go:192 - - pkg/sentry/kernel/thread_group.go:530 - - pkg/sentry/kernel/timekeeper.go:316 - - pkg/sentry/kernel/vdso.go:106 - - pkg/sentry/kernel/vdso.go:118 - pkg/sentry/memmap/memmap.go:103 - pkg/sentry/memmap/memmap.go:163 - - pkg/sentry/mm/address_space.go:42 - - pkg/sentry/mm/address_space.go:42 - pkg/sentry/mm/aio_context.go:208 - - pkg/sentry/mm/aio_context.go:288 - pkg/sentry/mm/pma.go:683 - - pkg/sentry/mm/special_mappable.go:80 - - pkg/sentry/syscalls/linux/sys_sem.go:62 - - pkg/sentry/syscalls/linux/sys_time.go:189 - pkg/sentry/usage/cpu.go:42 - - pkg/sentry/vfs/anonfs.go:302 - - pkg/sentry/vfs/anonfs.go:99 - - pkg/sentry/vfs/dentry.go:214 - - pkg/sentry/vfs/epoll.go:168 - - pkg/sentry/vfs/epoll.go:314 - - pkg/sentry/vfs/file_description.go:549 - - pkg/sentry/vfs/file_description_impl_util.go:304 - - pkg/sentry/vfs/file_description_impl_util.go:412 - - pkg/sentry/vfs/filesystem.go:76 - - pkg/sentry/vfs/lock.go:15 - - pkg/sentry/vfs/lock.go:47 - - pkg/sentry/vfs/memxattr/xattr.go:37 - - pkg/sentry/vfs/mount.go:510 - - pkg/sentry/vfs/mount.go:667 - - pkg/sentry/vfs/mount_test.go:106 - - pkg/sentry/vfs/mount_test.go:160 - - pkg/sentry/vfs/mount_test.go:215 - - pkg/sentry/vfs/mount_unsafe.go:153 - - pkg/sentry/vfs/resolving_path.go:228 - - pkg/sentry/vfs/vfs.go:897 - pkg/shim/runsc/runsc.go:16 - pkg/shim/runsc/utils.go:16 - pkg/shim/v1/proc/deleted_state.go:16 @@ -376,27 +155,10 @@ global: - pkg/state/tests/integer_test.go:28 - pkg/sync/rwmutex_test.go:105 - pkg/syserr/host_linux.go:35 - - pkg/unet/unet_test.go:634 - - pkg/unet/unet_test.go:662 - - pkg/unet/unet_test.go:703 - - pkg/unet/unet_test.go:98 - pkg/usermem/addr.go:34 - pkg/usermem/usermem.go:171 - pkg/usermem/usermem.go:170 - - runsc/boot/compat.go:22 - runsc/boot/compat.go:56 - - runsc/boot/loader.go:1115 - - runsc/boot/loader.go:1120 - - runsc/cmd/checkpoint.go:151 - - runsc/config/flags.go:32 - - runsc/container/container.go:641 - - runsc/container/container.go:988 - - runsc/specutils/specutils.go:172 - - runsc/specutils/specutils.go:428 - - runsc/specutils/specutils.go:436 - - runsc/specutils/specutils.go:442 - - runsc/specutils/specutils.go:447 - - runsc/specutils/specutils.go:454 - test/cmd/test_app/fds.go:171 - test/iptables/filter_output.go:251 - test/packetimpact/testbench/connections.go:77 @@ -412,13 +174,6 @@ global: - tools/checkescape/test1/test1.go:64 - tools/checkescape/test1/test1.go:80 - tools/checkescape/test1/test1.go:94 - - tools/go_generics/imports.go:51 - - tools/go_generics/imports.go:75 - - tools/go_marshal/gomarshal/generator.go:177 - - tools/go_marshal/gomarshal/generator.go:81 - - tools/go_marshal/gomarshal/generator.go:85 - - tools/go_marshal/test/escape/escape.go:15 - - tools/go_marshal/test/test.go:164 analyzers: asmdecl: external: # Enabled. diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index 069d0395d..6d1e65cb1 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error { case B: w.WriteString("1") default: - return fmt.Errorf("Invalid BPF LD size: %v", inst) + return fmt.Errorf("invalid BPF LD size: %v", inst) } return nil } diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 18457d287..e0a9e56c5 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -147,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 { // meatadata. type VerityDescriptor struct { Name string + FileSize int64 Mode uint32 UID uint32 GID uint32 @@ -154,7 +155,7 @@ type VerityDescriptor struct { } func (d *VerityDescriptor) String() string { - return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash) + return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash) } // verify generates a hash from d, and compares it with expected. @@ -289,6 +290,7 @@ func Generate(params *GenerateParams) ([]byte, error) { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, @@ -342,6 +344,7 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, @@ -401,10 +404,11 @@ func Verify(params *VerifyParams) (int64, error) { } } descriptor := VerityDescriptor{ - Name: params.Name, - Mode: params.Mode, - UID: params.UID, - GID: params.GID, + Name: params.Name, + FileSize: params.Size, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, } if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index 0782ca3e7..405204d94 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -185,42 +185,42 @@ func TestGenerate(t *testing.T) { { data: bytes.Repeat([]byte{0}, usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147}, }, { data: bytes.Repeat([]byte{0}, usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{14, 27, 126, 158, 9, 94, 163, 51, 243, 162, 82, 167, 183, 127, 93, 121, 221, 23, 184, 59, 104, 166, 111, 49, 161, 195, 229, 111, 121, 201, 233, 68, 10, 154, 78, 142, 154, 236, 170, 156, 110, 167, 15, 144, 155, 97, 241, 235, 202, 233, 246, 217, 138, 88, 152, 179, 238, 46, 247, 185, 125, 20, 101, 201}, + expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84}, }, { data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143}, }, { data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{55, 204, 240, 1, 224, 252, 58, 131, 251, 174, 45, 140, 107, 57, 118, 11, 18, 236, 203, 204, 19, 59, 27, 196, 3, 78, 21, 7, 22, 98, 197, 128, 17, 128, 90, 122, 54, 83, 253, 108, 156, 67, 59, 229, 236, 241, 69, 88, 99, 44, 127, 109, 204, 183, 150, 232, 187, 57, 228, 137, 209, 235, 241, 172}, + expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178}, }, { data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171}, }, { data: []byte{'a'}, hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{207, 233, 114, 94, 113, 212, 243, 160, 59, 232, 226, 77, 28, 81, 176, 61, 211, 213, 222, 190, 148, 196, 90, 166, 237, 56, 113, 148, 230, 154, 23, 105, 14, 97, 144, 211, 12, 122, 226, 207, 167, 203, 136, 193, 38, 249, 227, 187, 92, 238, 101, 97, 170, 255, 246, 209, 246, 98, 241, 150, 175, 253, 173, 206}, + expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120}, }, { data: bytes.Repeat([]byte{'a'}, usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246}, }, { data: bytes.Repeat([]byte{'a'}, usermem.PageSize), hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, - expectedHash: []byte{110, 103, 29, 250, 27, 211, 235, 119, 112, 65, 49, 156, 6, 92, 66, 105, 133, 1, 187, 172, 169, 13, 186, 34, 105, 72, 252, 131, 12, 159, 91, 188, 79, 184, 240, 227, 40, 164, 72, 193, 65, 31, 227, 153, 191, 6, 117, 42, 82, 122, 33, 255, 92, 215, 215, 249, 2, 131, 170, 134, 39, 192, 222, 33}, + expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84}, }, } @@ -275,6 +275,7 @@ func TestVerify(t *testing.T) { // fail, otherwise Verify should still succeed. modifyByte int64 modifyName bool + modifySize bool modifyMode bool modifyUID bool modifyGID bool @@ -323,6 +324,15 @@ func TestVerify(t *testing.T) { modifyName: true, shouldSucceed: false, }, + // Modified size should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifySize: true, + shouldSucceed: false, + }, // Modified mode should fail verification. { dataSize: usermem.PageSize, @@ -482,6 +492,9 @@ func TestVerify(t *testing.T) { if tc.modifyName { verifyParams.Name = defaultName + "abc" } + if tc.modifySize { + verifyParams.Size-- + } if tc.modifyMode { verifyParams.Mode = defaultMode + 1 } diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go index faf191f39..9fbc5466f 100644 --- a/pkg/refsvfs2/refs_map.go +++ b/pkg/refsvfs2/refs_map.go @@ -62,7 +62,9 @@ func Register(obj CheckedObject) { } liveObjects[obj] = struct{}{} liveObjectsMu.Unlock() - logEvent(obj, "registered") + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "registered") + } } } @@ -75,31 +77,40 @@ func Unregister(obj CheckedObject) { panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) } delete(liveObjects, obj) - logEvent(obj, "unregistered") + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "unregistered") + } } } // LogIncRef logs a reference increment. func LogIncRef(obj CheckedObject, refs int64) { - logEvent(obj, fmt.Sprintf("IncRef to %d", refs)) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("IncRef to %d", refs)) + } } // LogTryIncRef logs a successful TryIncRef call. func LogTryIncRef(obj CheckedObject, refs int64) { - logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs)) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs)) + } } // LogDecRef logs a reference decrement. func LogDecRef(obj CheckedObject, refs int64) { - logEvent(obj, fmt.Sprintf("DecRef to %d", refs)) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("DecRef to %d", refs)) + } } // logEvent logs a message for the given reference-counted object. +// +// obj.LogRefs() should be checked before calling logEvent, in order to avoid +// calling any text processing needed to evaluate msg. func logEvent(obj CheckedObject, msg string) { - if obj.LogRefs() { - log.Infof("[%s %p] %s:", obj.RefType(), obj, msg) - log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack())) - } + log.Infof("[%s %p] %s:", obj.RefType(), obj, msg) + log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack())) } // DoLeakCheck iterates through the live object map and logs a message for each diff --git a/pkg/refsvfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index 8f50b4ee6..f64b6c6ae 100644 --- a/pkg/refsvfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -13,10 +13,7 @@ // limitations under the License. // Package refs_template defines a template that can be used by reference -// counted objects. The "owner" template parameter is used in log messages to -// indicate the type of reference-counted object that exhibited a reference -// leak. As a result, structs that are embedded in other structs should not use -// this template, since it will make tracking down leaks more difficult. +// counted objects. package refs_template import ( @@ -43,9 +40,6 @@ var obj *T // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. // -// Note that the number of references is actually refCount + 1 so that a default -// zero-value Refs object contains one reference. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -58,6 +52,13 @@ type Refs struct { refCount int64 } +// InitRefs initializes r with one reference and, if enabled, activates leak +// checking. +func (r *Refs) InitRefs() { + atomic.StoreInt64(&r.refCount, 1) + refsvfs2.Register(r) +} + // RefType implements refsvfs2.CheckedObject.RefType. func (r *Refs) RefType() string { return fmt.Sprintf("%T", obj)[1:] @@ -81,8 +82,7 @@ func (r *Refs) EnableLeakCheck() { // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. func (r *Refs) ReadRefs() int64 { - // Account for the internal -1 offset on refcounts. - return atomic.LoadInt64(&r.refCount) + 1 + return atomic.LoadInt64(&r.refCount) } // IncRef implements refs.RefCounter.IncRef. @@ -90,8 +90,10 @@ func (r *Refs) ReadRefs() int64 { //go:nosplit func (r *Refs) IncRef() { v := atomic.AddInt64(&r.refCount, 1) - refsvfs2.LogIncRef(r, v+1) - if v <= 0 { + if enableLogging { + refsvfs2.LogIncRef(r, v) + } + if v <= 1 { panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) } } @@ -105,7 +107,7 @@ func (r *Refs) IncRef() { //go:nosplit func (r *Refs) TryIncRef() bool { const speculativeRef = 1 << 32 - if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 { + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) == 0 { // This object has already been freed. atomic.AddInt64(&r.refCount, -speculativeRef) return false @@ -113,7 +115,9 @@ func (r *Refs) TryIncRef() bool { // Turn into a real reference. v := atomic.AddInt64(&r.refCount, -speculativeRef+1) - refsvfs2.LogTryIncRef(r, v+1) + if enableLogging { + refsvfs2.LogTryIncRef(r, v) + } return true } @@ -131,12 +135,14 @@ func (r *Refs) TryIncRef() bool { //go:nosplit func (r *Refs) DecRef(destroy func()) { v := atomic.AddInt64(&r.refCount, -1) - refsvfs2.LogDecRef(r, v+1) + if enableLogging { + refsvfs2.LogDecRef(r, v+1) + } switch { - case v < -1: + case v < 0: panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) - case v == -1: + case v == 0: refsvfs2.Unregister(r) // Call the destructor. if destroy != nil { diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 3c66dc3c2..6b3627813 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // maxFilenameLen is the maximum length of a filename. This is dictated by 9P's @@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, } // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Wrap the fileOps with our Fifo. iops := &fifo{ diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index e555672ad..52061175f 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { } // GetFile implements fs.InodeOperations.GetFile. -func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true - return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil + return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil } // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index fc0498f17..d6c65301c 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Continue. seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil @@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // Write to that memory as usual. seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 998b697ca..cf4ed5de0 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -336,7 +336,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 346cca558..d8c237753 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -110,7 +110,7 @@ func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.Vir } root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - root.EnableLeakCheck() + root.InitRefs() var rootD kernfs.Dentry rootD.InitRoot(&fs.Filesystem, root) diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 5986133e9..95c475a65 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F readPayload.MarshalUnsafe(outBuf[outHdrLen:]) outIOseq := usermem.BytesIOSequence(outBuf) - n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) if err != nil { t.Fatalf("Write failed :%v", err) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 6de416da0..cd0eb56e5 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -219,16 +219,12 @@ func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt } fuseFD := device.Impl().(*DeviceFD) - fs := &filesystem{ devMinor: devMinor, opts: opts, conn: conn, } - - fs.VFSFilesystem().IncRef() fuseFD.fs = fs - return fs, nil } @@ -288,7 +284,7 @@ func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode i := &inode{fs: fs, nodeID: 1} i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - i.EnableLeakCheck() + i.InitRefs() var d kernfs.Dentry d.InitRoot(&fs.Filesystem, i) @@ -301,7 +297,7 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - i.EnableLeakCheck() + i.InitRefs() return i } diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index e1d9e3365..b2f4276b8 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -72,6 +72,7 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque if err != nil { return nil, nil, err } + fs.VFSFilesystem().Init(vfsObj, nil, fs) return fs.conn, &fuseDev.vfsfd, nil } diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index ce1b2a390..3b5927702 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -98,7 +98,9 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary - hostFD: -1, + readFD: -1, + writeFD: -1, + mmapFD: -1, nlink: uint32(2), } refsvfs2.Register(child) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 57a2ca43c..7ab298019 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -30,7 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -372,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if len(name) > maxFilenameLen { return syserror.ENAMETOOLONG } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.isDeleted() { return syserror.ENOENT } @@ -389,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if createInSyntheticDir == nil { return syserror.EPERM } @@ -408,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil && child.isSynthetic() { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // The existence of a non-synthetic dentry at name would be inconclusive // because the file it represents may have been deleted from the remote // filesystem, so we would need to make an RPC to revalidate the dentry. @@ -428,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. if err := createInRemoteDir(parent, name, &ds); err != nil { @@ -842,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode: opts.Mode, kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) return nil } @@ -1162,18 +1167,21 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { + openFD := int32(-1) + if fdobj != nil { + openFD = int32(fdobj.Release()) + } child.handleMu.Lock() if vfs.MayReadFileWithOpenFlags(opts.Flags) { child.readFile = openFile if fdobj != nil { - child.hostFD = int32(fdobj.Release()) + child.readFD = openFD + child.mmapFD = openFD } - } else if fdobj != nil { - // Can't use fdobj if it's not readable. - fdobj.Close() } if vfs.MayWriteFileWithOpenFlags(opts.Flags) { child.writeFile = openFile + child.writeFD = openFD } child.handleMu.Unlock() } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 6f82ce61b..53bcc9986 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -548,11 +548,16 @@ func (fs *filesystem) Release(ctx context.Context) { d.cache.DropAll(mf) d.dirty.RemoveAll() d.dataMu.Unlock() - // Close the host fd if one exists. - if d.hostFD >= 0 { - syscall.Close(int(d.hostFD)) - d.hostFD = -1 + // Close host FDs if they exist. + if d.readFD >= 0 { + syscall.Close(int(d.readFD)) } + if d.writeFD >= 0 && d.readFD != d.writeFD { + syscall.Close(int(d.writeFD)) + } + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 d.handleMu.Unlock() } // There can't be any specialFileFDs still using fs, since each such @@ -726,15 +731,17 @@ type dentry struct { // - If this dentry represents a regular file or directory, readFile is the // p9.File used for reads by all regularFileFDs/directoryFDs representing - // this dentry. + // this dentry, and readFD (if not -1) is a host FD equivalent to readFile + // used as a faster alternative. // // - If this dentry represents a regular file, writeFile is the p9.File - // used for writes by all regularFileFDs representing this dentry. + // used for writes by all regularFileFDs representing this dentry, and + // writeFD (if not -1) is a host FD equivalent to writeFile used as a + // faster alternative. // - // - If this dentry represents a regular file, hostFD is the host FD used - // for memory mappings and I/O (when applicable) in preference to readFile - // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must - // also be writable. If hostFD is -1, no such host FD is available. + // - If this dentry represents a regular file, mmapFD is the host FD used + // for memory mappings. If mmapFD is -1, no such FD is available, and the + // internal page cache implementation is used for memory mappings instead. // // These fields are protected by handleMu. // @@ -742,10 +749,17 @@ type dentry struct { // either p9.File transitions from closed (isNil() == true) to open // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. + // + // readFD and writeFD may or may not be the same file descriptor. mmapFD is + // always either -1 or equal to readFD; if !writeFile.isNil() (the file has + // been opened for writing), it is additionally either -1 or equal to + // writeFD. handleMu sync.RWMutex `state:"nosave"` readFile p9file `state:"nosave"` writeFile p9file `state:"nosave"` - hostFD int32 `state:"nosave"` + readFD int32 `state:"nosave"` + writeFD int32 `state:"nosave"` + mmapFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -829,7 +843,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, - hostFD: -1, + readFD: -1, + writeFD: -1, + mmapFD: -1, } d.pf.dentry = d if mask.UID { @@ -1220,7 +1236,9 @@ func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkCachingLocked(). r := atomic.AddInt64(&d.refs, 1) - refsvfs2.LogIncRef(d, r) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. @@ -1231,7 +1249,9 @@ func (d *dentry) TryIncRef() bool { return false } if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { - refsvfs2.LogTryIncRef(d, r+1) + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -1251,7 +1271,9 @@ func (d *dentry) DecRef(ctx context.Context) { // responsible for ensuring that d.checkCachingLocked will be called later. func (d *dentry) decRefNoCaching() int64 { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r < 0 { panic("gofer.dentry.decRefNoCaching() called without holding a reference") } @@ -1469,10 +1491,15 @@ func (d *dentry) destroyLocked(ctx context.Context) { } d.readFile = p9file{} d.writeFile = p9file{} - if d.hostFD >= 0 { - syscall.Close(int(d.hostFD)) - d.hostFD = -1 + if d.readFD >= 0 { + syscall.Close(int(d.readFD)) + } + if d.writeFD >= 0 && d.readFD != d.writeFD { + syscall.Close(int(d.writeFD)) } + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 d.handleMu.Unlock() if !d.file.isNil() { @@ -1584,7 +1611,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.handleMu.RUnlock() } - fdToClose := int32(-1) + var fdsToCloseArr [2]int32 + fdsToClose := fdsToCloseArr[:0] invalidateTranslations := false d.handleMu.Lock() if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { @@ -1615,56 +1643,88 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool return err } - if d.hostFD < 0 && h.fd >= 0 && openReadable && (d.writeFile.isNil() || openWritable) { - // We have no existing FD, and the new FD meets the requirements - // for d.hostFD, so start using it. - d.hostFD = h.fd - } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable { - // We have an existing read-only FD, but the file has just been - // opened for writing, so we need to start supporting writable memory - // mappings. This may race with callers of d.pf.FD() using the existing - // FD, so in most cases we need to delay closing the old FD until after - // invalidating memmap.Translations that might have observed it. - if !openReadable || h.fd < 0 { - // We don't have a read/write FD, so we have no FD that can be - // used to create writable memory mappings. Switch to using the - // internal page cache. - invalidateTranslations = true - fdToClose = d.hostFD - d.hostFD = -1 - } else if d.fs.opts.overlayfsStaleRead { - // We do have a read/write FD, but it may not be coherent with - // the existing read-only FD, so we must switch to mappings of - // the new FD in both the application and sentry. - if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) - h.close(ctx) - return err + // Update d.readFD and d.writeFD. + if h.fd >= 0 { + if openReadable && openWritable && (d.readFD < 0 || d.writeFD < 0 || d.readFD != d.writeFD) { + // Replace existing FDs with this one. + if d.readFD >= 0 { + // We already have a readable FD that may be in use by + // concurrent callers of d.pf.FD(). + if d.fs.opts.overlayfsStaleRead { + // If overlayfsStaleRead is in effect, then the new FD + // may not be coherent with the existing one, so we + // have no choice but to switch to mappings of the new + // FD in both the application and sentry. + if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) + h.close(ctx) + return err + } + fdsToClose = append(fdsToClose, d.readFD) + invalidateTranslations = true + d.readFD = h.fd + } else { + // Otherwise, we want to avoid invalidating existing + // memmap.Translations (which is expensive); instead, use + // dup3 to make the old file descriptor refer to the new + // file description, then close the new file descriptor + // (which is no longer needed). Racing callers of d.pf.FD() + // may use the old or new file description, but this + // doesn't matter since they refer to the same file, and + // any racing mappings must be read-only. + if err := syscall.Dup3(int(h.fd), int(d.readFD), syscall.O_CLOEXEC); err != nil { + oldFD := d.readFD + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldFD, err) + h.close(ctx) + return err + } + fdsToClose = append(fdsToClose, h.fd) + h.fd = d.readFD + } + } else { + d.readFD = h.fd } - invalidateTranslations = true - fdToClose = d.hostFD - d.hostFD = h.fd - } else { - // We do have a read/write FD. To avoid invalidating existing - // memmap.Translations (which is expensive), use dup3 to make - // the old file descriptor refer to the new file description, - // then close the new file descriptor (which is no longer - // needed). Racing callers of d.pf.FD() may use the old or new - // file description, but this doesn't matter since they refer - // to the same file, and any racing mappings must be read-only. - if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil { - oldHostFD := d.hostFD - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err) - h.close(ctx) - return err + if d.writeFD != h.fd && d.writeFD >= 0 { + fdsToClose = append(fdsToClose, d.writeFD) } - fdToClose = h.fd + d.writeFD = h.fd + d.mmapFD = h.fd + } else if openReadable && d.readFD < 0 { + d.readFD = h.fd + // If the file has not been opened for writing, the new FD may + // be used for read-only memory mappings. If the file was + // previously opened for reading (without an FD), then existing + // translations of the file may use the internal page cache; + // invalidate those mappings. + if d.writeFile.isNil() { + invalidateTranslations = !d.readFile.isNil() + d.mmapFD = h.fd + } + } else if openWritable && d.writeFD < 0 { + d.writeFD = h.fd + if d.readFD >= 0 { + // We have an existing read-only FD, but the file has just + // been opened for writing, so we need to start supporting + // writable memory mappings. However, the new FD is not + // readable, so we have no FD that can be used to create + // writable memory mappings. Switch to using the internal + // page cache. + invalidateTranslations = true + d.mmapFD = -1 + } + } else { + // The new FD is not useful. + fdsToClose = append(fdsToClose, h.fd) } - } else { - // h.fd is not useful. - fdToClose = h.fd + } else if openWritable && d.writeFD < 0 && d.mmapFD >= 0 { + // We have an existing read-only FD, but the file has just been + // opened for writing, so we need to start supporting writable + // memory mappings. However, we have no writable host FD. Switch to + // using the internal page cache. + invalidateTranslations = true + d.mmapFD = -1 } // Switch to new fids. @@ -1698,8 +1758,8 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.mappings.InvalidateAll(memmap.InvalidateOpts{}) d.mapsMu.Unlock() } - if fdToClose >= 0 { - syscall.Close(int(fdToClose)) + for _, fd := range fdsToClose { + syscall.Close(int(fd)) } return nil @@ -1709,7 +1769,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool func (d *dentry) readHandleLocked() handle { return handle{ file: d.readFile, - fd: d.hostFD, + fd: d.readFD, } } @@ -1717,7 +1777,7 @@ func (d *dentry) readHandleLocked() handle { func (d *dentry) writeHandleLocked() handle { return handle{ file: d.writeFile, - fd: d.hostFD, + fd: d.writeFD, } } @@ -1730,16 +1790,24 @@ func (d *dentry) syncRemoteFile(ctx context.Context) error { // Preconditions: d.handleMu must be locked. func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if d.hostFD >= 0 { + // RPC. Prefer syncing write handles over read handles, since some remote + // filesystem implementations may not sync changes made through write + // handles otherwise. + if d.writeFD >= 0 { ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(d.hostFD)) + err := syscall.Fsync(int(d.writeFD)) ctx.UninterruptibleSleepFinish(false) return err } if !d.writeFile.isNil() { return d.writeFile.fsync(ctx) } + if d.readFD >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(d.readFD)) + ctx.UninterruptibleSleepFinish(false) + return err + } if !d.readFile.isNil() { return d.readFile.fsync(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index dc8a890cb..652142ecc 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -326,7 +326,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) // dentry.readHandleLocked() without locking dentry.dataMu. rw.d.handleMu.RLock() h := rw.d.readHandleLocked() - if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off) rw.d.handleMu.RUnlock() rw.off += n @@ -446,7 +446,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // without locking dentry.dataMu. rw.d.handleMu.RLock() h := rw.d.writeHandleLocked() - if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + if (rw.d.mmapFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off) rw.off += n rw.d.dataMu.Lock() @@ -648,7 +648,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt return syserror.ENODEV } d.handleMu.RLock() - haveFD := d.hostFD >= 0 + haveFD := d.mmapFD >= 0 d.handleMu.RUnlock() if !haveFD { return syserror.ENODEV @@ -669,7 +669,7 @@ func (d *dentry) mayCachePages() bool { return true } d.handleMu.RLock() - haveFD := d.hostFD >= 0 + haveFD := d.mmapFD >= 0 d.handleMu.RUnlock() return haveFD } @@ -727,7 +727,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, // Translate implements memmap.Mappable.Translate. func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { d.handleMu.RLock() - if d.hostFD >= 0 && !d.fs.opts.forcePageCache { + if d.mmapFD >= 0 && !d.fs.opts.forcePageCache { d.handleMu.RUnlock() mr := optional if d.fs.opts.limitHostFDTranslation { @@ -881,7 +881,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { // cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file -// is available (i.e. dentry.hostFD >= 0), and that FD is used for application +// is available (i.e. dentry.mmapFD >= 0), and that FD is used for application // memory mappings (i.e. !filesystem.opts.forcePageCache). // // +stateify savable @@ -892,8 +892,8 @@ type dentryPlatformFile struct { // by dentry.dataMu. fdRefs fsutil.FrameRefSet - // If this dentry represents a regular file, and dentry.hostFD >= 0, - // hostFileMapper caches mappings of dentry.hostFD. + // If this dentry represents a regular file, and dentry.mmapFD >= 0, + // hostFileMapper caches mappings of dentry.mmapFD. hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. @@ -918,12 +918,12 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write) + return d.hostFileMapper.MapInternal(fr, int(d.mmapFD), at.Write) } // FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() defer d.handleMu.RUnlock() - return int(d.hostFD) + return int(d.mmapFD) } diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go index 17849dcc0..c90071e4e 100644 --- a/pkg/sentry/fsimpl/gofer/save_restore.go +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -139,7 +139,9 @@ func (d *dentry) beforeSave() { // afterLoad is invoked by stateify. func (d *dentry) afterLoad() { - d.hostFD = -1 + d.readFD = -1 + d.writeFD = -1 + d.mmapFD = -1 if atomic.LoadInt64(&d.refs) != -1 { refsvfs2.Register(d) } diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 39b902a3e..435a21d77 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -126,8 +126,8 @@ func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fil isTTY: isTTY, savable: savable, } + i.InitRefs() i.CachedMappable.Init(hostFD) - i.EnableLeakCheck() // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and // handle blocking behavior in the sentry. diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 8a447e29f..60acc367f 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -84,6 +84,8 @@ type ConnectedEndpoint struct { // init performs initialization required for creating new ConnectedEndpoints and // for restoring them. func (c *ConnectedEndpoint) init() *syserr.Error { + c.InitRefs() + family, err := syscall.GetsockoptInt(c.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN) if err != nil { return syserr.FromError(err) @@ -132,7 +134,6 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable // ConnectedEndpointRefs start off with a single reference. We need two. e.IncRef() - e.EnableLeakCheck() return &e, nil } @@ -376,8 +377,7 @@ func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr s return nil, err } - // ConnectedEndpointRefs start off with a single reference. We need two. + // e starts off with a single reference. We need two. e.IncRef() - e.EnableLeakCheck() return &e, nil } diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index f81056023..e77523f22 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { - if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return "", err +func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST + if name == "." || name == ".." { + return syserror.EEXIST } - if len(pc) > linux.NAME_MAX { - return "", syserror.ENAMETOOLONG + if len(name) > linux.NAME_MAX { + return syserror.ENAMETOOLONG } - if _, ok := parent.children[pc]; ok { - return "", syserror.EEXIST + if _, ok := parent.children[name]; ok { + return syserror.EEXIST } if parent.VFSDentry().IsDead() { - return "", syserror.ENOENT + return syserror.ENOENT } - return pc, nil + return nil } // checkDeleteLocked checks that the file represented by vfsd may be deleted. @@ -265,7 +264,7 @@ func (fs *Filesystem) Release(ctx context.Context) { // // Precondition: Filesystem.mu is held. func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) { - if d.inode.Keep() { + if d.inode.Keep() && d != d.fs.root { d.decRefLocked(ctx) } @@ -352,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -394,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -430,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } @@ -657,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDir) - switch err { + pc := rp.Component() + switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err { case nil: // Ok, continue with rename as replacement. case syserror.EEXIST: @@ -822,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index d9d76758a..eac578f25 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -568,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e return o.Unlink(ctx, name, child) } -// +stateify savable -type renameAcrossDifferentImplementationsError struct{} - -func (renameAcrossDifferentImplementationsError) Error() string { - return "rename across inodes with different implementations" -} - // Rename implements Inode.Rename. // // Precondition: Rename may only be called across two directory inodes with @@ -585,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string { // // Postcondition: reference on any replaced dentry transferred to caller. func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error { + if !o.writable { + return syserror.EPERM + } + dst, ok := dstDir.(interface{}).(*OrderedChildren) if !ok { - return renameAcrossDifferentImplementationsError{} + return syserror.EXDEV } - if !o.writable || !dst.writable { + if !dst.writable { return syserror.EPERM } + // Note: There's a potential deadlock below if concurrent calls to Rename // refer to the same src and dst directories in reverse. We avoid any // ordering issues because the caller is required to serialize concurrent @@ -662,7 +660,7 @@ var _ Inode = (*StaticDirectory)(nil) func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) - inode.EnableLeakCheck() + inode.InitRefs() inode.OrderedChildren.Init(OrderedChildrenOptions{}) links := inode.OrderedChildren.Populate(children) diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index abb477c7d..c14abcff4 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -222,7 +222,9 @@ func (d *Dentry) IncRef() { // d.refs may be 0 if d.fs.mu is locked, which serializes against // d.cacheLocked(). r := atomic.AddInt64(&d.refs, 1) - refsvfs2.LogIncRef(d, r) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. @@ -233,7 +235,9 @@ func (d *Dentry) TryIncRef() bool { return false } if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { - refsvfs2.LogTryIncRef(d, r+1) + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -242,7 +246,9 @@ func (d *Dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *Dentry) DecRef(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.fs.mu.Lock() d.cacheLocked(ctx) @@ -254,7 +260,9 @@ func (d *Dentry) DecRef(ctx context.Context) { func (d *Dentry) decRefLocked(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.cacheLocked(ctx) } else if r < 0 { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 2418eec44..e63588e33 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -109,7 +109,7 @@ func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credential dir := &readonlyDir{} dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - dir.EnableLeakCheck() + dir.InitRefs() dir.IncLinks(dir.OrderedChildren.Populate(contents)) return dir } @@ -147,7 +147,7 @@ func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode dir.fs = fs dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) - dir.EnableLeakCheck() + dir.InitRefs() dir.IncLinks(dir.OrderedChildren.Populate(contents)) return dir diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 4506642ca..469f3a33d 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er if dirent.Name == "." || dirent.Name == ".." { continue } - child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 04ca85f1a..bc07d72c0 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, lookupLayerNone, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, lookupLayerNone, err } afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } rp.Advance() - return d.parent, nil + return d.parent, d.parent.topLookupLayer(), nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err + return nil, topLookupLayer, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx) if err != nil { - return nil, err + return nil, lookupLayerNone, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, topLookupLayer, err } goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) { if child, ok := parent.children[name]; ok { - return child, nil + return child, child.topLookupLayer(), nil } - child, err := fs.lookupLocked(ctx, parent, name) + child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name) if err != nil { - return nil, err + return nil, topLookupLayer, err } if parent.children == nil { parent.children = make(map[string]*dentry) @@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s parent.children[name] = child // child's refcount is initially 0, so it may be dropped after traversal. *ds = appendDentry(*ds, child) - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. -func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { +func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) { childPath := fspath.Parse(name) child := fs.newDentry() - existsOnAnyLayer := false + topLookupLayer := lookupLayerNone var lookupErr error vfsObj := fs.vfsfs.VirtualFilesystem() @@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) - if !existsOnAnyLayer { + if topLookupLayer == lookupLayerNone { // Mode, UID, GID, and (for non-directories) inode number come from // the topmost layer on which the file exists. mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if isWhiteout(&stat) { // This is a whiteout, so it "doesn't exist" on this layer, and // layers below this one are ignored. + if isUpper { + topLookupLayer = lookupLayerUpperWhiteout + } return false } isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR - if existsOnAnyLayer && !isDir { + if topLookupLayer != lookupLayerNone && !isDir { // Directories are not merged with non-directory files from lower // layers; instead, layers including and below the first // non-directory file are ignored. (This file must be a directory @@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } else { child.lowerVDs = append(child.lowerVDs, childVD) } - if !existsOnAnyLayer { - existsOnAnyLayer = true + if topLookupLayer == lookupLayerNone { + if isUpper { + topLookupLayer = lookupLayerUpper + } else { + topLookupLayer = lookupLayerLower + } child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID @@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if lookupErr != nil { child.destroyLocked(ctx) - return nil, lookupErr + return nil, topLookupLayer, lookupErr } - if !existsOnAnyLayer { + if !topLookupLayer.existsInOverlay() { child.destroyLocked(ctx) - return nil, syserror.ENOENT + return nil, topLookupLayer, syserror.ENOENT } // Device and inode numbers were copied from the topmost layer above; @@ -306,7 +314,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if err != nil { ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) child.destroyLocked(ctx) - return nil, err + return nil, topLookupLayer, err } child.devMajor = linux.UNNAMED_MAJOR child.devMinor = childDevMinor @@ -315,7 +323,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str parent.IncRef() child.parent = parent child.name = name - return child, nil + return child, topLookupLayer, nil } // lookupLayerLocked is similar to lookupLocked, but only returns information @@ -414,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool { func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -434,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, d := rp.Start().Impl().(*dentry) for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -469,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if name == "." || name == ".." { return syserror.EEXIST } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.vfsd.IsDead() { return syserror.ENOENT } @@ -495,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } + // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { @@ -797,9 +806,9 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { - fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds) + fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err } @@ -899,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * // Preconditions: // * parent.dirMu must be locked. // * parent does not already contain a child named rp.Component(). -func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { +func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { return nil, err @@ -924,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving Start: parent.upperVD, Path: fspath.Parse(childName), } - // We don't know if a whiteout exists on the upper layer; speculatively - // unlink it. - // - // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do - // know whether a whiteout exists. - var haveUpperWhiteout bool - switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err { - case nil: - haveUpperWhiteout = true - case syserror.ENOENT: - haveUpperWhiteout = false - default: - return nil, err + // Unlink the whiteout if it exists. + if haveUpperWhiteout { + if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil { + log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err) + return nil, err + } } // Create the file on the upper layer, and get an FD representing it. upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{ @@ -967,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Re-lookup to get a dentry representing the new file, which is needed for // the returned FD. - child, err := fs.getChildLocked(ctx, parent, childName, ds) + child, _, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) @@ -1047,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } @@ -1079,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.vfsd.IsDead() { return syserror.ENOENT } - replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) - if err != nil { - return err - } var ( - replaced *dentry - replacedVFSD *vfs.Dentry - whiteouts map[string]bool + replaced *dentry + replacedVFSD *vfs.Dentry + replacedLayer lookupLayer + whiteouts map[string]bool ) - if replacedLayer.existsInOverlay() { - replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil { - return err - } + replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { + return err + } + if replaced != nil { replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { @@ -1296,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Unlike UnlinkAt, we need a dentry representing the child directory being // removed in order to verify that it's empty. - child, err := fs.getChildLocked(ctx, parent, name, &ds) + child, _, err := fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } @@ -1548,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if parentMode&linux.S_ISVTX != 0 { // If the parent's sticky bit is set, we need a child dentry to get // its owner. - child, err = fs.getChildLocked(ctx, parent, name, &ds) + child, _, err = fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index f6c58f2e7..3492409b2 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -514,7 +514,9 @@ func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkDropLocked(). r := atomic.AddInt64(&d.refs, 1) - refsvfs2.LogIncRef(d, r) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. @@ -525,7 +527,9 @@ func (d *dentry) TryIncRef() bool { return false } if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { - refsvfs2.LogTryIncRef(d, r+1) + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -534,7 +538,9 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) @@ -546,7 +552,9 @@ func (d *dentry) DecRef(ctx context.Context) { func (d *dentry) decRefLocked(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.checkDropLocked(ctx) } else if r < 0 { @@ -696,6 +704,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry { return vd } +func (d *dentry) topLookupLayer() lookupLayer { + if d.upperVD.Ok() { + return lookupLayerUpper + } + return lookupLayerLower +} + func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e44b79b68..0ecb592cf 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -101,7 +101,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem) *inode { creds := auth.CredentialsFromContext(ctx) return &inode{ - pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize), ino: fs.Filesystem.NextIno(), uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index cb3c5e0fd..e001d5032 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -60,7 +60,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, // Note: credentials are overridden by taskOwnedInode. subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - subInode.EnableLeakCheck() + subInode.InitRefs() inode := &taskOwnedInode{Inode: subInode, owner: task} return inode diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index 19011b010..dc46a09bc 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -91,7 +91,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - taskInode.EnableLeakCheck() + taskInode.InitRefs() inode := &taskOwnedInode{Inode: taskInode, owner: task} diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index d268b44be..3ec4471f5 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -128,7 +128,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { }, } inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } @@ -265,7 +265,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { }, } inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index b81ea14bf..151d1f10d 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -83,7 +83,7 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns cgroupControllers: cgroupControllers, } inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) - inode.EnableLeakCheck() + inode.InitRefs() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) links := inode.OrderedChildren.Populate(contents) diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 506a2a0f0..79bc3fe88 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -160,7 +160,7 @@ func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode d := &dir{} d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) - d.EnableLeakCheck() + d.InitRefs() d.IncLinks(d.OrderedChildren.Populate(contents)) return d } diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index d772db9e9..57e7b57b0 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -32,7 +31,7 @@ type namedPipe struct { // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} + file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. return &file.inode diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index 4ce859d57..85a3dfe20 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -402,7 +402,7 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth i.mtime = now // i.nlink initialized by caller i.impl = impl - i.refs.EnableLeakCheck() + i.refs.InitRefs() } // incLinksLocked increments i's link count. diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 2f6050cfd..4e8d63d51 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -276,9 +276,9 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de UID: parentStat.UID, GID: parentStat.GID, //TODO(b/156980949): Support passing other hash algorithms. - HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize(linux.FS_VERITY_HASH_ALG_SHA256)), + ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, }); err != nil && err != io.EOF { @@ -352,7 +352,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat UID: stat.UID, GID: stat.GID, //TODO(b/156980949): Support passing other hash algorithms. - HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index de92878fd..faa862c55 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -79,6 +79,27 @@ var ( verityMu sync.RWMutex ) +// HashAlgorithm is a type specifying the algorithm used to hash the file +// content. +type HashAlgorithm int + +// Currently supported hashing algorithms include SHA256 and SHA512. +const ( + SHA256 HashAlgorithm = iota + SHA512 +) + +func (alg HashAlgorithm) toLinuxHashAlg() int { + switch alg { + case SHA256: + return linux.FS_VERITY_HASH_ALG_SHA256 + case SHA512: + return linux.FS_VERITY_HASH_ALG_SHA512 + default: + return 0 + } +} + // FilesystemType implements vfs.FilesystemType. // // +stateify savable @@ -108,6 +129,10 @@ type filesystem struct { // stores the root hash of the whole file system in bytes. rootDentry *dentry + // alg is the algorithms used to hash the files in the verity file + // system. + alg HashAlgorithm + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -136,6 +161,10 @@ type InternalFilesystemOptions struct { // LowerName is the name of the filesystem wrapped by verity fs. LowerName string + // Alg is the algorithms used to hash the files in the verity file + // system. + Alg HashAlgorithm + // RootHash is the root hash of the overall verity file system. RootHash []byte @@ -194,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs := &filesystem{ creds: creds.Fork(), + alg: iopts.Alg, lowerMount: mnt, allowRuntimeEnable: iopts.AllowRuntimeEnable, } @@ -350,7 +380,9 @@ func (fs *filesystem) newDentry() *dentry { // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { r := atomic.AddInt64(&d.refs, 1) - refsvfs2.LogIncRef(d, r) + if d.LogRefs() { + refsvfs2.LogIncRef(d, r) + } } // TryIncRef implements vfs.DentryImpl.TryIncRef. @@ -361,7 +393,9 @@ func (d *dentry) TryIncRef() bool { return false } if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { - refsvfs2.LogTryIncRef(d, r+1) + if d.LogRefs() { + refsvfs2.LogTryIncRef(d, r+1) + } return true } } @@ -370,7 +404,9 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) @@ -382,7 +418,9 @@ func (d *dentry) DecRef(ctx context.Context) { func (d *dentry) decRefLocked(ctx context.Context) { r := atomic.AddInt64(&d.refs, -1) - refsvfs2.LogDecRef(d, r) + if d.LogRefs() { + refsvfs2.LogDecRef(d, r) + } if r == 0 { d.checkDropLocked(ctx) } else if r < 0 { @@ -627,7 +665,7 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, TreeReader: &merkleReader, TreeWriter: &merkleWriter, //TODO(b/156980949): Support passing other hash algorithms. - HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -873,7 +911,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of UID: fd.d.uid, GID: fd.d.gid, //TODO(b/156980949): Support passing other hash algorithms. - HashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index c647cbfd3..b2da9dd96 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -43,7 +43,7 @@ const maxDataSize = 100000 // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { k, err := testutil.Boot() if err != nil { t.Fatalf("testutil.Boot: %v", err) @@ -70,6 +70,7 @@ func newVerityRoot(t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *ke InternalData: InternalFilesystemOptions{ RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", + Alg: hashAlg, AllowRuntimeEnable: true, NoCrashOnVerificationFailure: true, }, @@ -161,280 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } +var hashAlgs = []HashAlgorithm{SHA256, SHA512} + // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Ensure that the corresponding Merkle tree file is created. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) - } - - // Ensure the root merkle tree file is created. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + rootMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } } } // TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity // file succeeds after enabling verity for it. func TestPReadUnmodifiedFileSucceeds(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.PRead: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } // TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity // file succeeds after enabling verity for it. func TestReadUnmodifiedFileSucceeds(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - buf := make([]byte, size) - n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.Read: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Ensure reopening the verity enabled file succeeds. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != nil { - t.Errorf("reopen enabled file failed: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } } } // TestPReadModifiedFileFails ensures that read from a modified verity file // fails. func TestPReadModifiedFileFails(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded, expected failure") + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded, expected failure") + } } } // TestReadModifiedFileFails ensures that read from a modified verity file // fails. func TestReadModifiedFileFails(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.Read succeeded, expected failure") + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") + } } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD - - lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerMerkleVD, - Start: lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the Merkle tree file. - stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - merkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - // Confirm that read from a file with modified Merkle tree fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) - t.Fatalf("fd.PRead succeeded with modified Merkle file") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } } } @@ -442,140 +459,146 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Enable verity on the parent directory. - parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD - - parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: parentLowerMerkleVD, - Start: parentLowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the parent Merkle tree file. - // This parent directory contains only one child, so any random - // modification in the parent Merkle tree should cause verification - // failure when opening the child file. - stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - parentMerkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - parentLowerMerkleFD.DecRef(ctx) - - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err == nil { - t.Errorf("OpenAt file with modified parent Merkle succeeded") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } } } // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { - t.Errorf("fd.Stat: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms stat succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { + t.Errorf("fd.Stat: %v", err) + } } } // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. func TestModifiedStatFails(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - - lowerFD := fd.Impl().(*fileDescription).lowerFD - // Change the stat of the underlying file, and check that stat fails. - if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: uint32(linux.STATX_MODE), - Mode: 0777, - }, - }); err != nil { - t.Fatalf("lowerFD.SetStat: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + lowerFD := fd.Impl().(*fileDescription).lowerFD + // Change the stat of the underlying file, and check that stat fails. + if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: uint32(linux.STATX_MODE), + Mode: 0777, + }, + }); err != nil { + t.Fatalf("lowerFD.SetStat: %v", err) + } - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { - t.Errorf("fd.Stat succeeded when it should fail") + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { + t.Errorf("fd.Stat succeeded when it should fail") + } } } @@ -616,84 +639,86 @@ func TestOpenDeletedFileFails(t *testing.T) { } for _, tc := range testCases { t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { - vfsObj, root, ctx, err := newVerityRoot(t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } - rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD - if tc.remove { - if tc.changeFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }); err != nil { - t.Fatalf("UnlinkAt: %v", err) + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.remove { + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } } - } - if tc.changeMerkleFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }); err != nil { - t.Fatalf("UnlinkAt: %v", err) + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } } - } - } else { - newFilename := "renamed-test-file" - if tc.changeFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(newFilename), - }, &vfs.RenameOptions{}); err != nil { - t.Fatalf("RenameAt: %v", err) + } else { + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } } - } - if tc.changeMerkleFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + newFilename), - }, &vfs.RenameOptions{}); err != nil { - t.Fatalf("UnlinkAt: %v", err) + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } } } - } - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != syserror.EIO { - t.Errorf("got OpenAt error: %v, expected EIO", err) + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } } }) } diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD index 364a78306..db3b0d0a0 100644 --- a/pkg/sentry/hostfd/BUILD +++ b/pkg/sentry/hostfd/BUILD @@ -6,10 +6,12 @@ go_library( name = "hostfd", srcs = [ "hostfd.go", + "hostfd_linux.go", "hostfd_unsafe.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/log", "//pkg/safemem", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go new file mode 100644 index 000000000..1cabc848f --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_linux.go @@ -0,0 +1,18 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostfd + +// maxIov is the maximum permitted size of a struct iovec array. +const maxIov = 1024 // UIO_MAXIOV diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go index cd4dc67fb..694371b1c 100644 --- a/pkg/sentry/hostfd/hostfd_unsafe.go +++ b/pkg/sentry/hostfd/hostfd_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" ) @@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6 } } else { iovs := safemem.IovecsFromBlockSeq(dsts) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { @@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint } } else { iovs := safemem.IovecsFromBlockSeq(srcs) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 3476551f3..470d8bf83 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -43,7 +43,7 @@ func (f *FDTable) initNoLeakCheck() { // init initializes the table with leak checking. func (f *FDTable) init() { f.initNoLeakCheck() - f.EnableLeakCheck() + f.InitRefs() } // get gets a file entry. diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 41fb2a784..dfde4deee 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -63,7 +63,7 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext { cwd: cwd, umask: umask, } - f.EnableLeakCheck() + f.InitRefs() return &f } @@ -76,7 +76,7 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext { cwdVFS2: cwd, umask: umask, } - f.EnableLeakCheck() + f.InitRefs() return &f } @@ -137,7 +137,7 @@ func (f *FSContext) Fork() *FSContext { rootVFS2: f.rootVFS2, umask: f.umask, } - ctx.EnableLeakCheck() + ctx.InitRefs() return ctx } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index b87e40dd1..9545bb5ef 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -41,7 +41,7 @@ func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace { semaphores: semaphore.NewRegistry(userNS), shms: shm.NewRegistry(userNS), } - ns.EnableLeakCheck() + ns.InitRefs() return ns } diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ce0db5583..d6fb0fdb8 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) type sleeper struct { @@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag d := fs.NewDirent(ctx, inode, "pipe") file, err := n.GetFile(ctx, d, flags) if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) + t.Errorf("open with flags %+v failed: %v", flags, err) + return nil, err } if doneChan != nil { doneChan <- struct{}{} @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 67beb0ad6..b989e14c7 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -26,18 +26,27 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) const ( // MinimumPipeSize is a hard limit of the minimum size of a pipe. - MinimumPipeSize = 64 << 10 + // It corresponds to fs/pipe.c:pipe_min_size. + MinimumPipeSize = usermem.PageSize + + // MaximumPipeSize is a hard limit on the maximum size of a pipe. + // It corresponds to fs/pipe.c:pipe_max_size. + MaximumPipeSize = 1048576 // DefaultPipeSize is the system-wide default size of a pipe in bytes. - DefaultPipeSize = MinimumPipeSize + // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS. + DefaultPipeSize = 16 * usermem.PageSize - // MaximumPipeSize is a hard limit on the maximum size of a pipe. - MaximumPipeSize = 8 << 20 + // atomicIOBytes is the maximum number of bytes that the pipe will + // guarantee atomic reads or writes atomically. + // It corresponds to limits.h:PIPE_BUF. + atomicIOBytes = 4096 ) // Pipe is an encapsulation of a platform-independent pipe. @@ -53,12 +62,6 @@ type Pipe struct { // This value is immutable. isNamed bool - // atomicIOBytes is the maximum number of bytes that the pipe will - // guarantee atomic reads or writes atomically. - // - // This value is immutable. - atomicIOBytes int64 - // The number of active readers for this pipe. // // Access atomically. @@ -94,47 +97,34 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // -// N.B. The size and atomicIOBytes will be bounded. -func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +// N.B. The size will be bounded. +func NewPipe(isNamed bool, sizeBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } var p Pipe - initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + initPipe(&p, isNamed, sizeBytes) return &p } -func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { +func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } pipe.isNamed = isNamed pipe.max = sizeBytes - pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. -func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) +func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) { + p := NewPipe(false /* isNamed */, sizeBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { wanted := ops.left() avail := p.max - p.view.Size() if wanted > avail { - if wanted <= p.atomicIOBytes { + if wanted <= atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index fe97e9800..3dd739080 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -26,7 +26,7 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) { const capacity = MinimumPipeSize ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) + r, w := NewConnectedPipe(ctx, capacity) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { const atomicIOBytes = 2 ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) + r, w := NewConnectedPipe(ctx, atomicIOBytes) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { } } if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) + t.Errorf("Readv: got unexpected error %v", err) + return } } }() diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index d96bf253b..7b23cbe86 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -54,9 +54,9 @@ type VFSPipe struct { } // NewVFSPipe returns an initialized VFSPipe. -func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { +func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { var vp VFSPipe - initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes) + initPipe(&vp.pipe, isNamed, sizeBytes) return &vp } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1145faf13..1abfe2201 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { + if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } _, err := word.CopyOut(t, data) @@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: word := t.Arch().Native(uintptr(data)) - _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) + _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c39ecfb8f..b99c0bffa 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -103,6 +103,7 @@ type waiter struct { waiterEntry // value represents how much resource the waiter needs to wake up. + // The value is either 0 or negative. value int16 ch chan struct{} } @@ -423,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) { return sem.pid, nil } +func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // The calling process must have read permission on the semaphore set. + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return 0, syserror.EACCES + } + + sem := s.findSem(num) + if sem == nil { + return 0, syserror.ERANGE + } + var cnt uint16 + for w := sem.waiters.Front(); w != nil; w = w.Next() { + if pred(w) { + cnt++ + } + } + return cnt, nil +} + +// CountZeroWaiters returns number of waiters waiting for the sem's value to increase. +func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value == 0 + }) +} + +// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero. +func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value < 0 + }) +} + // ExecuteOps attempts to execute a list of operations to the set. It only // succeeds when all operations can be applied. No changes are made if it fails. // @@ -575,11 +612,18 @@ func (s *Set) destroy() { } } +func abs(val int16) int16 { + if val < 0 { + return -val + } + return val +} + // wakeWaiters goes over all waiters and checks which of them can be notified. func (s *sem) wakeWaiters() { // Note that this will release all waiters waiting for 0 too. for w := s.waiters.Front(); w != nil; { - if s.value < w.value { + if s.value < abs(w.value) { // Still blocked, skip it. w = w.Next() continue diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index df5c8421b..0cd9e2533 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -295,7 +295,7 @@ func (tg *ThreadGroup) createSession() error { id: SessionID(id), leader: tg, } - s.EnableLeakCheck() + s.InitRefs() // Create a new ProcessGroup, belonging to that Session. // This also has a single reference (assigned below). @@ -309,7 +309,7 @@ func (tg *ThreadGroup) createSession() error { session: s, ancestors: 0, } - pg.refs.EnableLeakCheck() + pg.refs.InitRefs() // Tie them and return the result. s.processGroups.PushBack(pg) @@ -395,7 +395,7 @@ func (tg *ThreadGroup) CreateProcessGroup() error { originator: tg, session: tg.processGroup.session, } - pg.refs.EnableLeakCheck() + pg.refs.InitRefs() if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session { pg.ancestors++ @@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session { // // If this group isn't visible in this namespace, zero will be returned. It is // the callers responsibility to check that before using this function. -func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sids[s] +func (ns *PIDNamespace) IDOfSession(s *Session) SessionID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sids[s] } // SessionWithID returns the Session with the given ID in the PID namespace ns, // or nil if that given ID is not defined in this namespace. // // A reference is not taken on the session. -func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sessions[id] +func (ns *PIDNamespace) SessionWithID(id SessionID) *Session { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sessions[id] } // ProcessGroup returns the ThreadGroup's ProcessGroup. @@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup { // IDOfProcessGroup returns the process group assigned to pg in PID namespace ns. // // The same constraints apply as IDOfSession. -func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.pgids[pg] +func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.pgids[pg] } // ProcessGroupWithID returns the ProcessGroup with the given ID in the PID // namespace ns, or nil if that given ID is not defined in this namespace. // // A reference is not taken on the process group. -func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.processGroups[id] +func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.processGroups[id] } diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index ebbebf46b..92d60ba78 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -251,7 +251,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi creatorPID: pid, changeTime: ktime.NowFromContext(ctx), } - shm.EnableLeakCheck() + shm.InitRefs() // Find the next available ID. for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ { diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 682080c14..527344162 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } if opts.ChildSetTID { ctid := nt.ThreadID() - ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) + ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index ce134bf54..94dabbcd8 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,7 +18,8 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp }, nil } -// copyContext implements marshal.CopyContext. It wraps a task to allow copying -// memory to and from the task memory with custom usermem.IOOpts. -type copyContext struct { - *Task +type taskCopyContext struct { + ctx context.Context + t *Task opts usermem.IOOpts } -// AsCopyContext wraps the task and returns it as CopyContext. -func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { - return ©Context{t, opts} +// CopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. +func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext { + return &taskCopyContext{ + ctx: ctx, + t: t, + opts: opts, + } +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte { + if ctxTask, ok := cc.ctx.(*Task); ok { + return ctxTask.CopyScratchBuffer(size) + } + return make([]byte, size) +} + +func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) { + cc.t.mu.Lock() + tmm := cc.t.MemoryManager() + cc.t.mu.Unlock() + if !tmm.IncUsers() { + return nil, syserror.EFAULT + } + return tmm, nil +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyIn(cc.ctx, addr, dst, cc.opts) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyOut(cc.ctx, addr, src, cc.opts) +} + +type ownTaskCopyContext struct { + t *Task + opts usermem.IOOpts +} + +// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. The returned CopyContext may only be used by t's task +// goroutine. +// +// Since t already implements marshal.CopyContext, this is only needed to +// override the usermem.IOOpts used for the copy. +func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext { + return &ownTaskCopyContext{ + t: t, + opts: opts, + } } -// CopyInString copies a string in from the task's memory. -func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { - return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte { + return cc.t.CopyScratchBuffer(size) } -// CopyInBytes copies task memory into dst from an IO context. -func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { - return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts) } -// CopyOutBytes copies src into task memoryfrom an IO context. -func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { - return t.MemoryManager().CopyOut(t, addr, src, t.opts) +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts) } diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 9bc452e67..9e5c2d26f 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error { } if old != v.seq { - return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq) + return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq) } v.seq = next diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 7bf48cb2c..4c8cd38ed 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -252,7 +252,7 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { return nil, err } m := aioMappable{mfp: mfp, fr: fr} - m.EnableLeakCheck() + m.InitRefs() return &m, nil } diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 2dbe5b751..48d8b6a2b 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -44,7 +44,7 @@ type SpecialMappable struct { // Preconditions: fr.Length() != 0. func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} - m.EnableLeakCheck() + m.InitRefs() return &m } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 0a54dd30d..acad4c793 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) { c.runData.requestInterruptWindow = 0 } +// bluepillSigBus is reponsible for injecting NMI to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_NMI, 0); errno != 0 { + throw("NMI injection failed") + } +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 58f3d6fdd..965ad66b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -27,15 +27,20 @@ var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = syscall.SIGILL - // vcpuSErr is the event of system error. - vcpuSErr = kvmVcpuEvents{ + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce = kvmVcpuEvents{ exception: exception{ sErrPending: 1, - sErrHasEsr: 0, - pad: [6]uint8{0, 0, 0, 0, 0, 0}, - sErrEsr: 1, }, - rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } + + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI = kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, } ) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index b35c930e2..9433d4da5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 { + uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { + throw("sErr injection failed") + } +} + +// bluepillSigBus is reponsible for injecting sError to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { throw("sErr injection failed") } } diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index eb05950cd..75085ac6a 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( // escapes: no. - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_NMI, 0); errno != 0 { - throw("NMI injection failed") - } + bluepillSigBus(c) continue // Rerun vCPU. default: throw("run failed") diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index dd45ad10b..5979aef97 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { // NewAddressSpace returns a new pagetable root. func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. - pageTables := pagetables.New(newAllocator()) - k.machine.mapUpperHalf(pageTables) + pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 5831b9345..b060d9544 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -151,6 +151,9 @@ const ( _ESR_SEGV_PEMERR_L1 = 0xd _ESR_SEGV_PEMERR_L2 = 0xe _ESR_SEGV_PEMERR_L3 = 0xf + + // Custom ISS field definitions for system error. + _ESR_ELx_SERR_NMI = 0x1 ) // Arm64: MMIO base address used to dispatch hypercalls. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index f70d761fd..e2fffc99b 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -41,6 +41,9 @@ type machine struct { // slots are currently being updated, and the caller should retry. nextSlot uint32 + // upperSharedPageTables tracks the read-only shared upper of all the pagetables. + upperSharedPageTables *pagetables.PageTables + // kernel is the set of global structures. kernel ring0.Kernel @@ -199,9 +202,7 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }, m.maxVCPUs) + m.kernel.Init(m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -213,6 +214,13 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of slots is %d.", m.maxSlots) m.usedSlots = make([]uintptr, m.maxSlots) + // Create the upper shared pagetables and kernel(sentry) pagetables. + m.upperSharedPageTables = pagetables.New(newAllocator()) + m.mapUpperHalf(m.upperSharedPageTables) + m.upperSharedPageTables.Allocator.(*allocator).base.Drain() + m.upperSharedPageTables.MarkReadOnlyShared() + m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -226,7 +234,6 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index a8b729e62..8e03c310d 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -432,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } -var execRegions = func() (regions []region) { +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + // Map all the executible regions so that all the entry functions + // are mapped in the upper half. applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { return } + if vr.accessType.Execute { - regions = append(regions, vr.region) + r := vr.region + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) } }) - return -}() - -func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - for _, r := range execRegions { - physical, length, ok := translateToPhysical(r.virtual) - if !ok || length < r.length { - panic("impossilbe translation") - } - pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|r.virtual), - r.length, - pagetables.MapOpts{AccessType: usermem.Execute}, - physical) - } for start, end := range m.kernel.EntryRegions() { regionLen := end - start physical, length, ok := translateToPhysical(start) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 1344ed3c9..fd92c3873 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -221,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) } else if !ring0.IsCanonical(regs.Sp) { - return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info) } // Assign PCIDs. @@ -257,11 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) + case ring0.El0ErrNMI: + return c.fault(int32(syscall.SIGBUS), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef: + case ring0.El0SyncUndef: return c.fault(int32(syscall.SIGILL), info) - case ring0.El1Sync_undef: + case ring0.El1SyncUndef: *info = arch.SignalInfo{ Signo: int32(syscall.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 87a573cc4..327d48465 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -58,46 +58,55 @@ type Vector uintptr // Exception vectors. const ( - El1SyncInvalid = iota - El1IrqInvalid - El1FiqInvalid - El1ErrorInvalid + El1InvSync = iota + El1InvIrq + El1InvFiq + El1InvError + El1Sync El1Irq El1Fiq - El1Error + El1Err + El0Sync El0Irq El0Fiq - El0Error - El0Sync_invalid - El0Irq_invalid - El0Fiq_invalid - El0Error_invalid - El1Sync_da - El1Sync_ia - El1Sync_sp_pc - El1Sync_undef - El1Sync_dbg - El1Sync_inv - El0Sync_svc - El0Sync_da - El0Sync_ia - El0Sync_fpsimd_acc - El0Sync_sve_acc - El0Sync_sys - El0Sync_sp_pc - El0Sync_undef - El0Sync_dbg - El0Sync_inv + El0Err + + El0InvSync + El0InvIrq + El0InvFiq + El0InvErr + + El1SyncDa + El1SyncIa + El1SyncSpPc + El1SyncUndef + El1SyncDbg + El1SyncInv + + El0SyncSVC + El0SyncDa + El0SyncIa + El0SyncFpsimdAcc + El0SyncSveAcc + El0SyncSys + El0SyncSpPc + El0SyncUndef + El0SyncDbg + El0SyncInv + + El0ErrNMI + El0ErrBounce + _NR_INTERRUPTS ) // System call vectors. const ( - Syscall Vector = El0Sync_svc - PageFault Vector = El0Sync_da - VirtualizationException Vector = El0Error + Syscall Vector = El0SyncSVC + PageFault Vector = El0SyncDa + VirtualizationException Vector = El0ErrBounce ) // VirtualAddressBits returns the number bits available for virtual addresses. diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index e6daf24df..f9765771e 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -23,6 +23,9 @@ import ( // // This contains global state, shared by multiple CPUs. type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + KernelArchState } diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 00899273e..7a2275558 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -66,17 +66,9 @@ var ( KernelDataSegment SegmentDescriptor ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts - - // cpuEntries is array of kernelEntry for all cpus + // cpuEntries is array of kernelEntry for all cpus. cpuEntries []kernelEntry // globalIDT is our set of interrupt gates. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 508236e46..a014dcbc0 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -32,15 +32,8 @@ var ( KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts } // CPUArchState contains CPU-specific arch state. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index f9278b653..f489ad352 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -288,6 +288,10 @@ #define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) #define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +/* ISS field definitions for system error */ +#define ESR_ELx_SERR_MASK (0x1) +#define ESR_ELx_SERR_NMI (0x1) + // LOAD_KERNEL_ADDRESS loads a kernel address. #define LOAD_KERNEL_ADDRESS(from, to) \ MOVD from, to; \ @@ -691,7 +695,7 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - EXCEPTION_WITH_ERROR(1, El0Sync_undef) + EXCEPTION_WITH_ERROR(1, El0SyncUndef) el0_dbg: B ·Shutdown(SB) @@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0 TEXT ·El0_error(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + AND $ESR_ELx_SERR_MASK, R25, R24 + CMP $ESR_ELx_SERR_NMI, R24 + BEQ el0_nmi + B el0_bounce +el0_nmi: + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + WORD $0xd538601a //MRS FAR_EL1, R26 + + MOVD R26, CPU_FAULT_ADDR(RSV_REG) + + MOVD $1, R3 + MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. + + MOVD $El0ErrNMI, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) + +el0_bounce: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0 MOVD $VirtualizationException, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 264be23d3..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -16,11 +16,9 @@ package ring0 // Init initializes a new kernel. // -// N.B. that constraints on KernelOpts must be satisfied. -// //go:nosplit -func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { - k.init(opts, maxCPUs) +func (k *Kernel) Init(maxCPUs int) { + k.init(maxCPUs) } // Halt halts execution. diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 3a9dff4cc..b55dc29b3 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -24,10 +24,7 @@ import ( ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables - +func (k *Kernel) init(maxCPUs int) { entrySize := reflect.TypeOf(kernelEntry{}).Size() var ( entries []kernelEntry diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index b294ccc7c..6cbbf001f 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,9 +25,7 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables +func (k *Kernel) init(maxCPUs int) { } // init initializes architecture-specific state. diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 45eba960d..53bc3353c 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -47,43 +47,36 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) - fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) - fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) - fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) - fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err) fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) - fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err) - fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) - fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) - fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) - fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa) + fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa) + fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc) + fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef) + fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg) + fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv) - fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) - fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) - fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) - fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) - fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) - fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC) + fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa) + fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) + fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) + fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) + fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) + fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) + fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg) + fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv) - fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) - fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) - fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) - fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) - fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) - fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) - fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) - fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) - fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) - fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI) fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 7f18ac296..bc16a1622 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -30,6 +30,10 @@ type PageTables struct { Allocator Allocator // root is the pagetable root. + // + // For same archs such as amd64, the upper of the PTEs is cloned + // from and owned by upperSharedPageTables which are shared among + // many PageTables if upperSharedPageTables is not nil. root *PTEs // rootPhysical is the cached physical address of the root. @@ -39,15 +43,52 @@ type PageTables struct { // archPageTables includes architecture-specific features. archPageTables + + // upperSharedPageTables represents a read-only shared upper + // of the Pagetable. When it is not nil, the upper is not + // allowed to be modified. + upperSharedPageTables *PageTables + + // upperStart is the start address of the upper portion that + // are shared from upperSharedPageTables + upperStart uintptr + + // readOnlyShared indicates the Pagetables are read-only and + // own the ranges that are shared with other Pagetables. + readOnlyShared bool } -// New returns new PageTables. -func New(a Allocator) *PageTables { +// NewWithUpper returns new PageTables. +// +// upperSharedPageTables are used for mapping the upper of addresses, +// starting at upperStart. These pageTables should not be touched (as +// invalidations may be incorrect) after they are passed as an +// upperSharedPageTables. Only when all dependent PageTables are gone +// may they be used. The intenteded use case is for kernel page tables, +// which are static and fixed. +// +// Precondition: upperStart must be between canonical ranges. +// Precondition: upperStart must be pgdSize aligned. +// precondition: upperSharedPageTables must be marked read-only shared. +func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables { p := new(PageTables) p.Init(a) + if upperSharedPageTables != nil { + if !upperSharedPageTables.readOnlyShared { + panic("Only read-only shared pagetables can be used as upper") + } + p.upperSharedPageTables = upperSharedPageTables + p.upperStart = upperStart + p.cloneUpperShared() + } return p } +// New returns new PageTables. +func New(a Allocator) *PageTables { + return NewWithUpper(a, nil, 0) +} + // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. @@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true } // //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } if !opts.AccessType.Any() { return p.Unmap(addr, length) } @@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // True is returned iff there was a previous mapping in the range. // -// Precondition: addr & length must be page-aligned. +// Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack // //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } w := unmapWalker{ pageTables: p, visitor: unmapVisitor{ @@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) w.iterateRange(uintptr(addr), uintptr(addr)+1) return w.visitor.physical + offset, w.visitor.opts } + +// MarkReadOnlyShared marks the pagetables read-only and can be shared. +// +// It is usually used on the pagetables that are used as the upper +func (p *PageTables) MarkReadOnlyShared() { + p.readOnlyShared = true +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 520161755..a4e416af7 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -24,14 +24,6 @@ import ( // archPageTables is architecture-specific data. type archPageTables struct { - // root is the pagetable root for kernel space. - root *PTEs - - // rootPhysical is the cached physical address of the root. - // - // This is saved only to prevent constant translation. - rootPhysical uintptr - asid uint16 } @@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { // //go:nosplit func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { - return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset + return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset } // Bits in page table entries. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 0c153cf8c..e7ab887e5 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) { p.rootPhysical = p.Allocator.PhysicalFor(p.root) } +func pgdIndex(upperStart uintptr) uintptr { + if upperStart&(pgdSize-1) != 0 { + panic("upperStart should be pgd size aligned") + } + if upperStart >= upperBottom { + return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize + } + if upperStart < lowerTop { + return upperStart / pgdSize + } + panic("upperStart should be in canonical range") +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + start := pgdIndex(p.upperStart) + copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage]) +} + // PTEs is a collection of entries. type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 5ddd10256..5392bf27a 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) { p.Allocator = allocator p.root = p.Allocator.NewPTEs() p.rootPhysical = p.Allocator.PhysicalFor(p.root) - p.archPageTables.root = p.Allocator.NewPTEs() - p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + if p.upperStart != upperBottom { + panic("upperStart should be the same as upperBottom") + } + + // nothing to do for arm. } // PTEs is a collection of entries. diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go index c261d393a..157c9a7cc 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr { func (w *Walker) iterateRangeCanonical(start, end uintptr) { pgdEntryIndex := w.pageTables.root if start >= upperBottom { - pgdEntryIndex = w.pageTables.archPageTables.root + pgdEntryIndex = w.pageTables.upperSharedPageTables.root } for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 549787955..e0976fed0 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf // marshalTarget and unmarshalTarget can be used. type targetMaker interface { // id uniquely identifies the target. - id() stack.TargetID + id() targetID - // marshal converts from a stack.Target to an ABI struct. - marshal(target stack.Target) []byte + // marshal converts from a target to an ABI struct. + marshal(target target) []byte - // unmarshal converts from the ABI matcher struct to a stack.Target. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) + // unmarshal converts from the ABI matcher struct to a target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) } -// targetMakers maps the TargetID of supported targets to the targetMaker that +// A targetID uniquely identifies a target. +type targetID struct { + // name is the target name as stored in the xt_entry_target struct. + name string + + // networkProtocol is the protocol to which the target applies. + networkProtocol tcpip.NetworkProtocolNumber + + // revision is the version of the target. + revision uint8 +} + +// target extends a stack.Target, allowing it to be used with the extension +// system. The sentry only uses targets, never stack.Targets directly. +type target interface { + stack.Target + id() targetID +} + +// targetMakers maps the targetID of supported targets to the targetMaker that // marshals and unmarshals it. It is immutable after package initialization. -var targetMakers = map[stack.TargetID]targetMaker{} +var targetMakers = map[targetID]targetMaker{} func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { - tid := stack.TargetID{ - Name: name, - NetworkProtocol: netProto, - Revision: rev, + tid := targetID{ + name: name, + networkProtocol: netProto, + revision: rev, } if _, ok := targetMakers[tid]; !ok { return 0, false @@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8 // Return the highest supported revision unless rev is higher. for _, other := range targetMakers { otherID := other.id() - if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { - rev = uint8(otherID.Revision) + if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev { + rev = uint8(otherID.revision) } } return rev, true @@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) { targetMakers[tm.id()] = tm } -func marshalTarget(target stack.Target) []byte { - targetMaker, ok := targetMakers[target.ID()] +func marshalTarget(tgt stack.Target) []byte { + // The sentry only uses targets, never stack.Targets directly. + target := tgt.(target) + targetMaker, ok := targetMakers[target.id()] if !ok { - panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id())) } return targetMaker.marshal(target) } -func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { - tid := stack.TargetID{ - Name: target.Name.String(), - NetworkProtocol: filter.NetworkProtocol(), - Revision: target.Revision, +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) { + tid := targetID{ + name: target.Name.String(), + networkProtocol: filter.NetworkProtocol(), + revision: target.Revision, } targetMaker, ok := targetMakers[tid] if !ok { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index b560fae0d..70c561cce 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), false) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct. - entries, info := getEntries4(table, tablename) + entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 4253f7bf4..5dbb604f0 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), true) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct, which is the same in IPv4 and IPv6. - entries, info := getEntries6(table, tablename) + entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 904a12e38..b283d7229 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) { } } +// Table names. +const ( + natTable = "nat" + mangleTable = "mangle" + filterTable = "filter" +) + +// nameToID is immutable. +var nameToID = map[string]stack.TableID{ + natTable: stack.NATID, + mangleTable: stack.MangleID, + filterTable: stack.FilterID, +} + +// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for +// compatibility with netfilter extensions. +func DefaultLinuxTables() *stack.IPTables { + tables := stack.DefaultTables() + tables.VisitTargets(func(oldTarget stack.Target) stack.Target { + switch val := oldTarget.(type) { + case *stack.AcceptTarget: + return &acceptTarget{AcceptTarget: *val} + case *stack.DropTarget: + return &dropTarget{DropTarget: *val} + case *stack.ErrorTarget: + return &errorTarget{ErrorTarget: *val} + case *stack.UserChainTarget: + return &userChainTarget{UserChainTarget: *val} + case *stack.ReturnTarget: + return &returnTarget{ReturnTarget: *val} + case *stack.RedirectTarget: + return &redirectTarget{RedirectTarget: *val} + default: + panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val)) + } + }) + return tables +} + // GetInfo returns information about iptables. func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. @@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.FilterTable: + case filterTable: table = stack.EmptyFilterTable() - case stack.NATTable: + case natTable: table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) @@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } if offset == replace.Underflow[hook] { if !validUnderflow(table.Rules[ruleIdx], ipv6) { - nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx) return syserr.ErrInvalidArgument } table.Underflows[hk] = ruleIdx @@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) - + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case *stack.AcceptTarget, *stack.DropTarget: + case *acceptTarget, *dropTarget: return true default: return false @@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(*stack.AcceptTarget) + _, ok := rule.Target.(*acceptTarget) return ok } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 0e14447fe..f2653d523 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -26,6 +26,15 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port and/or IP for packets. +const RedirectTargetName = "REDIRECT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -52,25 +61,96 @@ func init() { }) } +// The stack package provides some basic, useful targets for us. The following +// types wrap them for compatibility with the extension system. + +type acceptTarget struct { + stack.AcceptTarget +} + +func (at *acceptTarget) id() targetID { + return targetID{ + networkProtocol: at.NetworkProtocol, + } +} + +type dropTarget struct { + stack.DropTarget +} + +func (dt *dropTarget) id() targetID { + return targetID{ + networkProtocol: dt.NetworkProtocol, + } +} + +type errorTarget struct { + stack.ErrorTarget +} + +func (et *errorTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: et.NetworkProtocol, + } +} + +type userChainTarget struct { + stack.UserChainTarget +} + +func (uc *userChainTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: uc.NetworkProtocol, + } +} + +type returnTarget struct { + stack.ReturnTarget +} + +func (rt *returnTarget) id() targetID { + return targetID{ + networkProtocol: rt.NetworkProtocol, + } +} + +type redirectTarget struct { + stack.RedirectTarget + + // addr must be (un)marshalled when reading and writing the target to + // userspace, but does not affect behavior. + addr tcpip.Address +} + +func (rt *redirectTarget) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rt.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (sm *standardTargetMaker) id() stack.TargetID { +func (sm *standardTargetMaker) id() targetID { // Standard targets have the empty string as a name and no revisions. - return stack.TargetID{ - NetworkProtocol: sm.NetworkProtocol, + return targetID{ + networkProtocol: sm.NetworkProtocol, } } -func (*standardTargetMaker) marshal(target stack.Target) []byte { + +func (*standardTargetMaker) marshal(target target) []byte { // Translate verdicts the same way as the iptables tool. var verdict int32 switch tg := target.(type) { - case *stack.AcceptTarget: + case *acceptTarget: verdict = -linux.NF_ACCEPT - 1 - case *stack.DropTarget: + case *dropTarget: verdict = -linux.NF_DROP - 1 - case *stack.ReturnTarget: + case *returnTarget: verdict = linux.NF_RETURN case *JumpTarget: verdict = int32(tg.Offset) @@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTStandardTarget { nflog("buf has wrong size for standard target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -114,20 +194,20 @@ type errorTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (em *errorTargetMaker) id() stack.TargetID { +func (em *errorTargetMaker) id() targetID { // Error targets have no revision. - return stack.TargetID{ - Name: stack.ErrorTargetName, - NetworkProtocol: em.NetworkProtocol, + return targetID{ + name: ErrorTargetName, + networkProtocol: em.NetworkProtocol, } } -func (*errorTargetMaker) marshal(target stack.Target) []byte { +func (*errorTargetMaker) marshal(target target) []byte { var errorName string switch tg := target.(type) { - case *stack.ErrorTarget: - errorName = stack.ErrorTargetName - case *stack.UserChainTarget: + case *errorTarget: + errorName = ErrorTargetName + case *userChainTarget: errorName = tg.Name default: panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) @@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte { }, } copy(xt.Name[:], errorName) - copy(xt.Target.Name[:], stack.ErrorTargetName) + copy(xt.Target.Name[:], ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTErrorTarget { nflog("buf has insufficient size for error target %d", len(buf)) return nil, syserr.ErrInvalidArgument } - var errorTarget linux.XTErrorTarget + var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &errTgt) // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named stack.ErrorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. + // * An actual error case. These rules have an error named + // ErrorTargetName. The last entry of the table is usually an error + // case to catch any packets that somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case stack.ErrorTargetName: - return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + switch name := errTgt.Name.String(); name { + case ErrorTargetName: + return &errorTarget{stack.ErrorTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }}, nil default: // User defined chain. - return &stack.UserChainTarget{ + return &userChainTarget{stack.UserChainTarget{ Name: name, NetworkProtocol: filter.NetworkProtocol(), - }, nil + }}, nil } } @@ -178,22 +259,22 @@ type redirectTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *redirectTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *redirectTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*redirectTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*redirectTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) // This is a redirect target named redirect xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(xt.Target.Name[:], stack.RedirectTargetName) + copy(xt.Target.Name[:], RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 @@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) < linux.SizeOfXTRedirectTarget { nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - var redirectTarget linux.XTRedirectTarget + var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &rt) // Copy linux.XTRedirectTarget to stack.RedirectTarget. - target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + target := redirectTarget{RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} // RangeSize should be 1. - nfRange := redirectTarget.NfRange + nfRange := rt.NfRange if nfRange.RangeSize != 1 { nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) return nil, syserr.ErrInvalidArgument @@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) target.Port = ntohs(nfRange.RangeIPV4.MinPort) return &target, nil @@ -264,15 +347,15 @@ type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *nfNATTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *nfNATTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*nfNATTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*nfNATTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ TargetSize: nfNATMarhsalledSize, @@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, } - copy(nt.Target.Name[:], stack.RedirectTargetName) - copy(nt.Range.MinAddr[:], rt.Addr) - copy(nt.Range.MaxAddr[:], rt.Addr) + copy(nt.Target.Name[:], RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.addr) + copy(nt.Range.MaxAddr[:], rt.addr) nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto @@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, nt) } -func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if size := nfNATMarhsalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument @@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta return nil, syserr.ErrInvalidArgument } - target := stack.RedirectTarget{ - NetworkProtocol: filter.NetworkProtocol(), - Addr: tcpip.Address(natRange.MinAddr[:]), - Port: ntohs(natRange.MinProto), + target := redirectTarget{ + RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Port: ntohs(natRange.MinProto), + }, + addr: tcpip.Address(natRange.MinAddr[:]), } return &target, nil @@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return &stack.AcceptTarget{NetworkProtocol: netProto}, nil + return &acceptTarget{stack.AcceptTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_DROP - 1: - return &stack.DropTarget{NetworkProtocol: netProto}, nil + return &dropTarget{stack.DropTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_QUEUE - 1: nflog("unsupported iptables verdict QUEUE") return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return &stack.ReturnTarget{NetworkProtocol: netProto}, nil + return &returnTarget{stack.ReturnTarget{ + NetworkProtocol: netProto, + }}, nil default: nflog("unknown iptables verdict %d", val) return nil, syserr.ErrInvalidArgument @@ -382,9 +473,9 @@ type JumpTarget struct { } // ID implements Target.ID. -func (jt *JumpTarget) ID() stack.TargetID { - return stack.TargetID{ - NetworkProtocol: jt.NetworkProtocol, +func (jt *JumpTarget) id() targetID { + return targetID{ + networkProtocol: jt.NetworkProtocol, } } diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 22216158e..f4d034c13 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -487,7 +487,7 @@ func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlin Addr: value, }) if err != nil { - return syserr.ErrInvalidArgument + return syserr.ErrBadLocalAddress } case linux.IFA_ADDRESS: default: diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index aa4f3c04d..6d9e502bd 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -142,9 +142,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E } q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} - q1.EnableLeakCheck() + q1.InitRefs() q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - q2.EnableLeakCheck() + q2.InitRefs() if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} @@ -300,14 +300,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} - readQueue.EnableLeakCheck() + readQueue.InitRefs() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - writeQueue.EnableLeakCheck() + writeQueue.InitRefs() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index f8aacca13..1406971bc 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -42,7 +42,7 @@ var ( func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} - q.EnableLeakCheck() + q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} return ep } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index adad485a9..b32bb7ba8 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -80,7 +80,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty stype: stype, }, } - s.EnableLeakCheck() + s.InitRefs() return fs.NewFile(ctx, d, flags, &s) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 7a78444dc..eaf0b0d26 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -80,7 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } - sock.EnableLeakCheck() + sock.InitRefs() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 36902d177..bb1f715e2 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 849a47476..f7135ea46 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return 0, syserror.EINVAL } - r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) + r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize) r.SetFlags(linuxToFlags(flags).Settable()) defer r.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index c2d4bf805..e383a0a87 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -138,12 +138,18 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, err + case linux.GETZCNT: + v, err := getZCnt(t, id, num) + return uintptr(v), nil, err + + case linux.GETNCNT: + v, err := getNCnt(t, id, num) + return uintptr(v), nil, err + case linux.IPC_INFO, linux.SEM_INFO, linux.SEM_STAT, - linux.SEM_STAT_ANY, - linux.GETNCNT, - linux.GETZCNT: + linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -258,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) { } return int32(tg.ID()), nil } + +func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountZeroWaiters(num, creds) +} + +func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountNegativeWaiters(num, creds) +} diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 46616c961..1c4cdb0dd 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inCh chan struct{} outCh chan struct{} ) + for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) opts.Length -= n @@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inW, _ := waiter.NewChannelEntry(inCh) inFile.EventRegister(&inW, EventMaskRead) defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Need to refresh readiness. + continue } if err = t.Block(inCh); err != nil { break } } - if outFile.Readiness(EventMaskWrite) == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, EventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + // We might be ready to write now. Try again before + // blocking. + continue + } + if err = t.Block(outCh); err != nil { + break } } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 035e2a6b0..9ce4f280a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { // waitForOut waits for dw.outfile to be read. func (dw *dualWaiter) waitForOut(t *kernel.Task) error { - if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if dw.outCh == nil { - dw.outW, dw.outCh = waiter.NewChannelEntry(nil) - dw.outFile.EventRegister(&dw.outW, eventMaskWrite) - // We might be ready now. Try again before blocking. - return nil - } - if err := t.Block(dw.outCh); err != nil { - return err - } - } - return nil + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. See b/172075629, b/170743336. + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready to write now. Try again before blocking. + return nil + } + return t.Block(dw.outCh) } // destroy cleans up resources help by dw. No more calls to wait* can occur diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 546e445aa..936f9fc71 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -133,7 +133,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou } } - fd.EnableLeakCheck() + fd.InitRefs() // Remove "file creation flags" to mirror the behavior from file.f_flags in // fs/open.c:do_dentry_open. diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index c93d94634..2c4b81e78 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -48,7 +48,7 @@ type Filesystem struct { // Init must be called before first use of fs. func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) { - fs.EnableLeakCheck() + fs.InitRefs() fs.vfs = vfsObj fs.fsType = fsType fs.impl = impl diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 3f0b8f45b..107171b61 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -65,7 +65,7 @@ type Inotify struct { // queue is used to notify interested parties when the inotify instance // becomes readable or writable. - queue waiter.Queue `state:"nosave"` + queue waiter.Queue // evMu *only* protects the events list. We need a separate lock while // queuing events: using mu may violate lock ordering, since at that point diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 3ea981ad4..d865fd603 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -169,7 +169,7 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth Owner: creds.UserNamespace, mountpoints: make(map[*Dentry]uint32), } - mntns.EnableLeakCheck() + mntns.InitRefs() mntns.root = newMount(vfs, fs, root, mntns, opts) return mntns, nil } @@ -477,7 +477,9 @@ func (mnt *Mount) tryIncMountedRef() bool { return false } if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) { - refsvfs2.LogTryIncRef(mnt, r+1) + if mnt.LogRefs() { + refsvfs2.LogTryIncRef(mnt, r+1) + } return true } } @@ -488,12 +490,17 @@ func (mnt *Mount) IncRef() { // In general, negative values for mnt.refs are valid because the MSB is // the eager-unmount bit. r := atomic.AddInt64(&mnt.refs, 1) - refsvfs2.LogIncRef(mnt, r) + if mnt.LogRefs() { + refsvfs2.LogIncRef(mnt, r) + } } // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { r := atomic.AddInt64(&mnt.refs, -1) + if mnt.LogRefs() { + refsvfs2.LogDecRef(mnt, r) + } if r&^math.MinInt64 == 0 { // mask out MSB refsvfs2.Unregister(mnt) mnt.destroy(ctx) diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index cb8c56bd3..cb882a983 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) { parent := &Mount{} point := &Dentry{} if m := mt.Lookup(parent, point); m != nil { - t.Errorf("empty mountTable lookup: got %p, wanted nil", m) + t.Errorf("Empty mountTable lookup: got %p, wanted nil", m) } } @@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] mi, ok := ms.Load(k) if !ok { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } m := mi.(*Mount) if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m, _ := ms.Load(k) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD index f08599ebd..cb0001852 100644 --- a/pkg/shim/runsc/BUILD +++ b/pkg/shim/runsc/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "@com_github_containerd_containerd//log:go_default_library", "@com_github_containerd_go_runc//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index c5cf68efa..e7c9640ba 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -28,10 +28,12 @@ import ( "syscall" "time" + "github.com/containerd/containerd/log" runc "github.com/containerd/go-runc" specs "github.com/opencontainers/runtime-spec/specs-go" ) +// Monitor is the default process monitor to be used by runsc. var Monitor runc.ProcessMonitor = runc.Monitor // DefaultCommand is the default command for Runsc. @@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro return &c, nil } +// CreateOpts is a set of options to Runsc.Create(). type CreateOpts struct { runc.IO ConsoleSocket runc.ConsoleSocket @@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) { return res.ExitStatus, nil } +// ExecOpts is a set of options to runsc.Exec(). type ExecOpts struct { runc.IO PidFile string @@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts return Monitor.Wait(cmd, ec) } +// DeleteOpts is a set of options to runsc.Delete(). type DeleteOpts struct { Force bool } @@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { if err := json.NewDecoder(rd).Decode(&e); err != nil { return nil, err } + log.L.Debugf("Stats returned: %+v", e.Stats) + if e.Type != "stats" { + return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type) + } + if e.Stats == nil { + return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`) + } return e.Stats, nil } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 09cb153b1..4e7e5f76a 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -375,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback +// address. +func IsV6LoopbackAddress(addr tcpip.Address) bool { + return addr == IPv6Loopback +} + // IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 // link-local multicast address. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 560477926..b3e8c4b92 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // // We don't clone the original packet buffer so that the new packet buffer // does not have any of its headers set. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) + // + // We trim the link headers from the cloned buffer as the sniffer doesn't + // handle link headers. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + vv.TrimFront(len(pkt.LinkHeader().View())) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) switch protocol { case header.IPv4ProtocolNumber: if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index cda6328a2..4c14f55d3 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -157,7 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } - endpoint.EnableLeakCheck() + endpoint.InitRefs() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index a79379abb..33a4a0720 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -122,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } @@ -145,7 +145,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } @@ -158,6 +158,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 969579601..8873bd91f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { + netHdr := pkt.Network() + t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -608,7 +609,8 @@ func TestIPv4Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -707,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -788,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } @@ -800,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -900,7 +906,8 @@ func TestIPv6Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -1017,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -1071,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum tcpip.NetworkProtocolNumber nicAddr tcpip.Address remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.View + pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) expectedErr *tcpip.Error }{ @@ -1081,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1095,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1123,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1137,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1147,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1156,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1166,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1175,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1203,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) @@ -1221,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) } - return hdr.View() + return hdr.View().ToVectorisedView() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with options and data across views", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + vv := buffer.View(ip).ToVectorisedView() + vv.AppendView(ipv4Options) + vv.AppendView(data) + return vv }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1251,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1264,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1291,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1307,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1334,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1342,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1369,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1377,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1421,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { defer r.Release() if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + Data: test.pktGen(t, subTest.srcAddr), })); err != test.expectedErr { t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index cf287446e..9b5e37fee 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -42,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { return } @@ -58,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { - stats := r.Stats() +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() received := stats.ICMP.V4PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a @@ -83,7 +83,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // packets with checksum errors. switch h.Type() { case header.ICMPv4Echo: - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) } return } @@ -106,7 +106,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } else { op = &optionUsageReceive{} } - aux, tmp, err := processIPOptions(r, iph.Options(), op) + aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) if err != nil { switch { case @@ -116,9 +116,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidLength), errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() } return } @@ -131,7 +131,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { received.Echo.Increment() sent := stats.ICMP.V4PacketsSent - if !r.Stack().AllowICMPMessage() { + if !e.protocol.stack.AllowICMPMessage() { sent.RateLimited.Increment() return } @@ -144,10 +144,13 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast // It's possible that a raw socket expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) pkt = nil + // Take the base of the incoming request IP header but replace the options. replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) @@ -156,12 +159,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). - localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr := ipHdr.DestinationAddress() + if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -218,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() @@ -307,7 +310,11 @@ func (*icmpReasonParamProblem) isICMPReason() {} // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv4(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -331,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in // response to a non-initial fragment, but it currently can not happen. - - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { return nil } @@ -340,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil sent := p.stack.Stats().ICMP.V4PacketsSent if !p.stack.AllowICMPMessage() { @@ -355,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. - if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { // TODO(gvisor.dev/issue/3810): // Unfortunately the current stack pretty much always has ICMPv4 headers // in the Data section of the packet but there is no guarantee that is the @@ -416,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -428,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // view with the entire incoming IP packet reassembled and truncated as // required. This is now the payload of the new ICMP packet and no longer // considered a packet in its own right. - newHeader := append(buffer.View(nil), networkHeader...) + newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 4592984a5..cfd0c505a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -252,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() return nil @@ -270,16 +269,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, pkt) - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil @@ -373,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -403,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return tcpip.ErrMalformedHeader } + + hdrLen := header.IPv4(h).HeaderLength() + if hdrLen < header.IPv4MinimumSize { + return tcpip.ErrMalformedHeader + } + + h, ok = pkt.Data.PullUp(int(hdrLen)) + if !ok { + return tcpip.ErrMalformedHeader + } ip := header.IPv4(h) // Always set the total length. @@ -447,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -480,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -488,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -498,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -506,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -520,8 +545,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } @@ -537,12 +562,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { var releaseCB func(bool) if start == 0 { pkt := pkt.Clone() - r := r.Clone() releaseCB = func(timedOut bool) { if timedOut { - _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) } - r.Release() } } @@ -566,8 +589,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if !ready { @@ -579,7 +602,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) h.SetFlagsFragmentOffset(0, 0) } - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { @@ -587,14 +610,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // headers, the setting of the transport number here should be // unnecessary and removed. pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt) + e.handleICMP(pkt) return } if len(h.Options()) != 0 { // TODO(gvisor.dev/issue/4586): // When we add forwarding support we should use the verified options // rather than just throwing them away. - aux, _, err := processIPOptions(r, h.Options(), &optionUsageReceive{}) + aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) if err != nil { switch { case @@ -604,15 +627,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { errors.Is(err, errIPv4TimestampOptInvalidLength), errors.Is(err, errIPv4TimestampOptInvalidPointer), errors.Is(err, errIPv4TimestampOptOverflow): - _ = e.protocol.returnError(r, &icmpReasonParamProblem{pointer: aux}, pkt) - e.protocol.stack.Stats().MalformedRcvdPackets.Increment() - r.Stats().IP.MalformedPacketsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() } return } } - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination @@ -620,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -669,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo loopback := e.nic.IsLoopback() addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { - subnet := addressEndpoint.AddressWithPrefix().Subnet() + subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) @@ -919,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head originalIPHeaderLength := len(originalIPHeader) nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) @@ -1172,8 +1196,8 @@ func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Ad // - The location of an error if there was one (or 0 if no error) // - If there is an error, information as to what it was was. // - The replacement option set. -func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { - +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + stats := e.protocol.stack.Stats() opts := header.IPv4Options(orig) optIter := opts.MakeIterator() @@ -1186,13 +1210,15 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag // This will need tweaking when we start really forwarding packets // as we may need to get two addresses, for rx and tx interfaces. // We will also have to take usage into account. - prefixedAddress, err := r.Stack().GetMainNICAddress(r.NICID(), ProtocolNumber) + prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) localAddress := prefixedAddress.Address if err != nil { - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress } - localAddress = r.LocalAddress + localAddress = dstAddr } for { @@ -1219,9 +1245,9 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag optLen := int(option.Size()) switch option := option.(type) { case *header.IPv4OptionTimestamp: - r.Stats().IP.OptionTSReceived.Increment() + stats.IP.OptionTSReceived.Increment() if usage.actions().timestamp != optionRemove { - clock := r.Stack().Clock() + clock := e.protocol.stack.Clock() newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) @@ -1232,7 +1258,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag } case *header.IPv4OptionRecordRoute: - r.Stats().IP.OptionRRReceived.Increment() + stats.IP.OptionRRReceived.Increment() if usage.actions().recordRoute != optionRemove { newBuffer := optIter.RemainingBuffer()[:len(*option)] _ = copy(newBuffer, option.Contents()) @@ -1244,7 +1270,7 @@ func processIPOptions(r *stack.Route, orig header.IPv4Options, usage optionsUsag } default: - r.Stats().IP.OptionUnknownReceived.Increment() + stats.IP.OptionUnknownReceived.Increment() if usage.actions().unknown == optionPass { newBuffer := optIter.RemainingBuffer()[:optLen] // Arguments already heavily checked.. ignore result. diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 61672a5ff..c7f434591 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2178,13 +2178,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -2199,17 +2196,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 3c15e41a7..8502b848c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { }) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := r.Stats().ICMP +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { + stats := e.protocol.stack.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their @@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) + srcAddr := iph.SourceAddress() + dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. // // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { received.Invalid.Increment() return } @@ -224,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // we know we are also performing DAD on it). In this case we let the // stack know so it can handle such a scenario and do nothing further with // the NS. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -251,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -277,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. - unspecifiedSource := r.RemoteAddress == header.IPv6Any + unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { - if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { received.Invalid.Increment() return } @@ -287,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -298,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // ... // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. - if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { received.Invalid.Increment() return } @@ -308,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the source of the solicitation is the unspecified address, the node // MUST [...] and multicast the advertisement to the all-nodes address. // - remoteAddr := r.RemoteAddress + remoteAddr := srcAddr if unspecifiedSource { remoteAddr = header.IPv6AllNodesMulticastAddress } @@ -465,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. - localAddr := r.LocalAddress - if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr := dstAddr + if header.IsV6MulticastAddress(dstAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -486,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -498,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -519,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - stack := r.Stack() + stack := e.protocol.stack // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { @@ -550,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { received.Invalid.Increment() return } @@ -558,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -575,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - routerAddr := iph.SourceAddress() + routerAddr := srcAddr // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { @@ -608,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -753,7 +759,11 @@ func (*icmpReasonReassemblyTimeout) isICMPReason() {} // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv6(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -780,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { return nil } @@ -788,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil stats := p.stack.Stats().ICMP sent := stats.V6PacketsSent diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index aa8b5f2e5..76013daa1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -87,7 +87,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { +func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { return stack.TransportPacketHandled } @@ -282,7 +282,8 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -424,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -1796,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) { SrcAddr: r.RemoteAddress, DstAddr: r.LocalAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. if nudHandler.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 1e38f3a9d..0526190cc 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return err } - prefix := addressEndpoint.AddressWithPrefix().Subnet() + prefix := addressEndpoint.Subnet() switch t := addressEndpoint.ConfigType(); t { case stack.AddressConfigStatic: @@ -465,21 +465,27 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ - // The inbound path expects an unparsed packet. - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) - - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil @@ -576,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -637,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() // As per RFC 4291 section 2.7: // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. - if header.IsV6MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if header.IsV6MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -671,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -681,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -693,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -705,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -719,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -732,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -757,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -794,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if done { @@ -822,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } } @@ -831,8 +844,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } @@ -845,9 +858,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // of the fragment, pointing to the Payload Length field of the // fragment packet. if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: header.IPv6PayloadLenOffset, }, pkt) @@ -866,9 +879,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // the fragment, pointing to the Fragment Offset field of the fragment // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: fragmentFieldOffset, }, pkt) @@ -880,12 +893,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { var releaseCB func(bool) if start == 0 { pkt := pkt.Clone() - r := r.Clone() releaseCB = func(timedOut bool) { if timedOut { - _ = e.protocol.returnError(&r, &icmpReasonReassemblyTimeout{}, pkt) + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) } - r.Release() } } @@ -895,8 +906,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ - Source: h.SourceAddress(), - Destination: h.DestinationAddress(), + Source: srcAddr, + Destination: dstAddr, ID: extHdr.ID(), }, start, @@ -907,8 +918,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } pkt.Data = data @@ -927,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -941,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -954,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -977,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt, hasFragmentHeader) + e.handleICMP(pkt, hasFragmentHeader) } else { - r.Stats().IP.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + stats.IP.PacketsDelivered.Increment() + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -991,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -1012,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -1022,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) - r.Stats().UnknownProtocolRcvdPackets.Increment() + stats.UnknownProtocolRcvdPackets.Increment() return } } @@ -1635,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea originalIPHeadersLength := len(originalIPHeaders) fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index c593c0004..1bfcdde25 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -2360,13 +2360,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2381,17 +2378,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 7f2ebc0cb..981d1371a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) @@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } var tllData [header.NDPLinkLayerAddressSize]byte diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 261705575..9478f3fb7 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState = &addressState{ addressableEndpointState: a, addr: addr, + // Cache the subnet in addrState to avoid calls to addr.Subnet() as that + // results in allocations on every call. + subnet: addr.Subnet(), } a.mu.endpoints[addr.Address] = addrState addrState.mu.Lock() @@ -666,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil) type addressState struct { addressableEndpointState *AddressableEndpointState addr tcpip.AddressWithPrefix - + subnet tcpip.Subnet // Lock ordering (from outer to inner lock ordering): // // AddressableEndpointState.mu @@ -686,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr } +// Subnet implements AddressEndpoint. +func (a *addressState) Subnet() tcpip.Subnet { + return a.subnet +} + // GetKind implements AddressEndpoint. func (a *addressState) GetKind() AddressKind { a.mu.RLock() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 0cd1da11f..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.Addr - replyTID.srcPort = rt.Port + replyTID.srcAddr = address + replyTID.srcPort = port var manip manipType switch hook { case Prerouting: @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index 380688038..7a501acdc 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 8d6d9a7f1..2d8c883cd 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -22,30 +22,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -// tableID is an index into IPTables.tables. -type tableID int +// TableID identifies a specific table. +type TableID int +// Each value identifies a specific table. const ( - natID tableID = iota - mangleID - filterID - numTables + NATID TableID = iota + MangleID + FilterID + NumTables ) -// Table names. -const ( - NATTable = "nat" - MangleTable = "mangle" - FilterTable = "filter" -) - -// nameToID is immutable. -var nameToID = map[string]tableID{ - NATTable: natID, - MangleTable: mangleID, - FilterTable: filterID, -} - // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 @@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - v4Tables: [numTables]Table{ - natID: Table{ + v4Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -81,7 +68,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -99,7 +86,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -122,8 +109,8 @@ func DefaultTables() *IPTables { }, }, }, - v6Tables: [numTables]Table{ - natID: Table{ + v6Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -146,7 +133,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -164,7 +151,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -187,10 +174,10 @@ func DefaultTables() *IPTables { }, }, }, - priorities: [NumHooks][]tableID{ - Prerouting: []tableID{mangleID, natID}, - Input: []tableID{natID, filterID}, - Output: []tableID{mangleID, natID, filterID}, + priorities: [NumHooks][]TableID{ + Prerouting: []TableID{MangleID, NATID}, + Input: []TableID{NATID, FilterID}, + Output: []TableID{MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -229,26 +216,20 @@ func EmptyNATTable() Table { } } -// GetTable returns a table by name. -func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - id, ok := nameToID[name] - if !ok { - return Table{}, false - } +// GetTable returns a table with the given id and IP version. It panics when an +// invalid id is provided. +func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { it.mu.RLock() defer it.mu.RUnlock() if ipv6 { - return it.v6Tables[id], true + return it.v6Tables[id] } - return it.v4Tables[id], true + return it.v4Tables[id] } -// ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - id, ok := nameToID[name] - if !ok { - return tcpip.ErrInvalidOptionValue - } +// ReplaceTable replaces or inserts table by name. It panics when an invalid id +// is provided. +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to // check the NAT table. - if tableID == natID && pkt.NatDone { + if tableID == NATID && pkt.NatDone { continue } var table Table diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 538c4625d..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -26,13 +28,6 @@ type AcceptTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (at *AcceptTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: at.NetworkProtocol, - } -} - // Action implements Target.Action. func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 @@ -44,22 +39,11 @@ type DropTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (dt *DropTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: dt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } -// ErrorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const ErrorTargetName = "ERROR" - // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. type ErrorTarget struct { @@ -67,14 +51,6 @@ type ErrorTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (et *ErrorTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: et.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") @@ -90,14 +66,6 @@ type UserChainTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (uc *UserChainTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: uc.NetworkProtocol, - } -} - // Action implements Target.Action. func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") @@ -110,50 +78,39 @@ type ReturnTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *ReturnTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } -// RedirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const RedirectTargetName = "REDIRECT" - -// RedirectTarget redirects the packet by modifying the destination port/IP. +// RedirectTarget redirects the packet to this machine by modifying the +// destination port/IP. Outgoing packets are redirected to the loopback device, +// and incoming packets are redirected to the incoming interface (rather than +// forwarded). +// // TODO(gvisor.dev/issue/170): Other flags need to be added after we support // them. type RedirectTarget struct { - // Addr indicates address used to redirect. - Addr tcpip.Address - - // Port indicates port used to redirect. + // Port indicates port used to redirect. It is immutable. Port uint16 - // NetworkProtocol is the network protocol the target is used with. + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *RedirectTarget) ID() TargetID { - return TargetID{ - Name: RedirectTargetName, - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither +// implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + address = tcpip.Address([]byte{127, 0, 0, 1}) } else { - rt.Addr = header.IPv6Loopback + address = header.IPv6Loopback } case Prerouting: - rt.Addr = address + // No-op, as address is already set correctly. default: panic("redirect target is supported only on output and prerouting hooks") } @@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(rt.Addr) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7b3f3e88b..4b86c1be9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -37,7 +37,6 @@ import ( // ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> type Hook uint -// These values correspond to values in include/uapi/linux/netfilter.h. const ( // Prerouting happens before a packet is routed to applications or to // be forwarded. @@ -86,8 +85,8 @@ type IPTables struct { mu sync.RWMutex // v4Tables and v6tables map tableIDs to tables. They hold builtin // tables only, not user tables. mu must be locked for accessing. - v4Tables [numTables]Table - v6Tables [numTables]Table + v4Tables [NumTables]Table + v6Tables [NumTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. @@ -96,7 +95,7 @@ type IPTables struct { // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. - priorities [NumHooks][]tableID + priorities [NumHooks][]TableID connections ConnTrack @@ -104,6 +103,24 @@ type IPTables struct { reaperDone chan struct{} } +// VisitTargets traverses all the targets of all tables and replaces each with +// transform(target). +func (it *IPTables) VisitTargets(transform func(Target) Target) { + it.mu.Lock() + defer it.mu.Unlock() + + for tid := range it.v4Tables { + for i, rule := range it.v4Tables[tid].Rules { + it.v4Tables[tid].Rules[i].Target = transform(rule.Target) + } + } + for tid := range it.v6Tables { + for i, rule := range it.v6Tables[tid].Rules { + it.v6Tables[tid].Rules[i].Target = transform(rule.Target) + } + } +} + // A Table defines a set of chains and hooks into the network stack. // // It is a list of Rules, entry points (BuiltinChains), and error handlers @@ -169,7 +186,6 @@ type IPHeaderFilter struct { // CheckProtocol determines whether the Protocol field should be // checked during matching. - // TODO(gvisor.dev/issue/3549): Check this field during matching. CheckProtocol bool // Dst matches the destination IP address. @@ -309,23 +325,8 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } -// A TargetID uniquely identifies a target. -type TargetID struct { - // Name is the target name as stored in the xt_entry_target struct. - Name string - - // NetworkProtocol is the protocol to which the target applies. - NetworkProtocol tcpip.NetworkProtocolNumber - - // Revision is the version of the target. - Revision uint8 -} - // A Target is the interface for taking an action for a packet. type Target interface { - // ID uniquely identifies the Target. - ID() TargetID - // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index aec77610d..493e48031 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -17,12 +17,19 @@ package stack import ( "fmt" "sync" + "time" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) +const ( + // immediateDuration is a duration of zero for scheduling work that needs to + // be done immediately but asynchronously to avoid deadlock. + immediateDuration time.Duration = 0 +) + // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { Addr tcpip.Address @@ -242,7 +249,12 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { e.job.Schedule(config.RetransmitTimer) } - sendUnicastProbe() + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(immediateDuration) case Failed: e.notifyWakersLocked() @@ -324,7 +336,12 @@ func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { e.job.Schedule(config.RetransmitTimer) } - sendMulticastProbe() + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(immediateDuration) case Stale: e.setStateLocked(Delay) diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index d297c9422..c2b763325 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -47,6 +47,12 @@ const ( entryTestNetDefaultMTU = 65536 ) +// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { + clock.Advance(immediateDuration) +} + // eventDiffOpts are the options passed to cmp.Diff to compare entry events. // The UpdatedAtNanos field is ignored due to a lack of a deterministic method // to predict the time that an event will be dispatched. @@ -308,7 +314,7 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { func TestEntryUnknownToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) @@ -317,6 +323,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -354,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { func TestEntryUnknownToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) @@ -364,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Unlock() // No probes should have been sent. + runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) linkRes.mu.Unlock() @@ -488,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { func TestEntryIncompleteToReachable(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -519,6 +520,17 @@ func TestEntryIncompleteToReachable(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -552,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { // to Reachable. func TestEntryAddsAndClearsWakers(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) w := sleep.Waker{} s := sleep.Sleeper{} @@ -561,6 +573,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -584,20 +614,6 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -627,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -660,6 +666,20 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { } linkRes.mu.Unlock() + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -689,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { func TestEntryIncompleteToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -720,6 +733,17 @@ func TestEntryIncompleteToStale(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -830,12 +854,30 @@ func (*testLocker) Unlock() {} func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -861,20 +903,6 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -910,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -945,6 +959,24 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -983,16 +1015,9 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1007,6 +1032,17 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ @@ -1053,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1085,6 +1110,21 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1119,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1165,6 +1184,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1199,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1245,6 +1262,25 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1279,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1324,6 +1340,24 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1353,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1392,6 +1408,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1430,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1455,20 +1511,6 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1507,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1546,6 +1570,28 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1584,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1619,6 +1651,24 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1657,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { func TestEntryStaleToDelay(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1689,6 +1728,21 @@ func TestEntryStaleToDelay(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1736,21 +1790,9 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleUpperLevelConfirmationLocked() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1765,8 +1807,23 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleUpperLevelConfirmationLocked() + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1833,28 +1890,9 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1869,8 +1907,30 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -1937,6 +1997,24 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1959,22 +2037,7 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2031,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2071,6 +2115,29 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2109,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2142,6 +2197,22 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2189,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked(entryTestAddr2) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2226,6 +2281,26 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2277,6 +2352,27 @@ func TestEntryDelayToProbe(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2289,25 +2385,19 @@ func TestEntryDelayToProbe(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ @@ -2367,6 +2457,27 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2376,25 +2487,19 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2459,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { @@ -2473,6 +2572,27 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2482,25 +2602,19 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2569,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { @@ -2583,6 +2691,27 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2592,25 +2721,20 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2696,9 +2820,7 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - // Probe caused by the Delay-to-Probe transition { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, @@ -2729,7 +2851,6 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2795,6 +2916,27 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2804,25 +2946,19 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2843,7 +2979,6 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -2918,6 +3053,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -2927,25 +3083,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2963,7 +3113,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -3038,6 +3187,27 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -3047,25 +3217,19 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -3083,7 +3247,6 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, @@ -3158,9 +3321,9 @@ func TestEntryProbeToFailed(t *testing.T) { e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) { wantProbes := []entryTestProbeInfo{ - // Caused by the Unknown-to-Incomplete transition. { RemoteAddress: entryTestAddr1, LocalAddress: entryTestAddr2, @@ -3283,6 +3446,26 @@ func TestEntryFailedGetsDeleted(t *testing.T) { e.mu.Lock() e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -3293,33 +3476,28 @@ func TestEntryFailedGetsDeleted(t *testing.T) { waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The next three probe are sent in Probe. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 17f2e6b46..60c81a3aa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -348,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -555,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { } func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + n.getNetworkEndpoint(protocol).HandlePacket(pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -594,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] @@ -669,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.nic + n := r.localAddressNIC if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote + pkt.NICID = n.ID() r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + pkt.NetworkPacketInfo = r.networkPacketInfo() + n.getNetworkEndpoint(protocol).HandlePacket(pkt) addressEndpoint.DecRef() r.Release() return @@ -735,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -747,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -776,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -791,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -895,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 4af04846f..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. func (e *testIPv6Endpoint) Close() { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 203f3b51f..b8f333057 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,28 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // RemoteAddressBroadcast is true if the packet's remote address is a + // broadcast address. + RemoteAddressBroadcast bool + + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +116,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -172,9 +183,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +238,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface { // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix + // Subnet returns the subnet of the endpoint's address. + Subnet() tcpip.Subnet + // IsAssigned returns whether or not the endpoint is considered bound // to its NetworkEndpoint. IsAssigned(allowExpired bool) bool @@ -547,7 +561,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index b76e2d37b..15ff437c7 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,11 +47,16 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +67,144 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + addrWithPrefix := addressEndpoint.AddressWithPrefix() + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + addressEndpoint.DecRef() + return Route{} + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = addrWithPrefix.Address + } + + r := makeRoute( + netProto, + addrWithPrefix.Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, - Loop: loop, + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + localAddressEndpoint: localAddressEndpoint, + outgoingNIC: outgoingNIC, + Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// PopulatePacketInfo populates a packet buffer's packet information fields. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { + if r.local() { + pkt.RXTransportChecksumValidated = true + } + pkt.NetworkPacketInfo = r.networkPacketInfo() +} + +// networkPacketInfo returns the network packet information of the route. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) networkPacketInfo() NetworkPacketInfo { + return NetworkPacketInfo{ + RemoteAddressBroadcast: r.IsOutboundBroadcast(), + LocalAddressBroadcast: r.isInboundBroadcast(), + } +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. + var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } @@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } @@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before the this route can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + if r.localAddressEndpoint != nil { + r.localAddressEndpoint.DecRef() + r.localAddressEndpoint = nil } } // Clone clones the route. func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() + if r.localAddressEndpoint != nil { + if !r.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) + } } return *r } @@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + subnet := r.localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool { return r.isV4Broadcast(r.RemoteAddress) } -// IsInboundBroadcast returns true if the route is for an inbound broadcast +// isInboundBroadcast returns true if the route is for an inbound broadcast // packet. -func (r *Route) IsInboundBroadcast() bool { +func (r *Route) isInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.LocalAddress) } @@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool { // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + localAddressEndpoint: r.localAddressEndpoint, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index e8f1c110e..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -518,6 +519,10 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // IPTables are the initial iptables rules. If nil, iptables will allow + // all traffic. + IPTables *IPTables } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +625,10 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } + if opts.IPTables == nil { + opts.IPTables = DefaultTables() + } + opts.NUDConfigs.resetInvalidFields() s := &Stack{ @@ -633,7 +642,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - tables: DefaultTables(), + tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, @@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -1194,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } + + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1457,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1910,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4eed4ced4..dedfdd435 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + pkt := pkt.Clone() + r.PopulatePacketInfo(pkt) + f.HandlePacket(pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. +func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { RemoteAddress: ipv4SubnetBcast, RemoteLinkAddress: header.EthernetBroadcastAddress, NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, + Loop: stack.PacketOut | stack.PacketLoop, }, }, // Broadcast to a locally attached /31 subnet does not populate the @@ -3756,3 +3862,369 @@ func TestRemoveRoutes(t *testing.T) { } } } + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + defer r.Release() + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 6b8071467..c457b67a2 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return + } + route.ResolveWith(pkt.SourceLinkAddress()) + + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + }) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 34aab32d0..9b0f3b675 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -10,6 +10,7 @@ go_test( "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", + "route_test.go", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0dcef7b04..bf7594268 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -33,11 +33,6 @@ import ( func TestForwarding(t *testing.T) { const ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - host1NICID = 1 routerNICID1 = 2 routerNICID2 = 3 @@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) { } }, }, + { + name: "IPv4 host2 server with routerNIC1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 routerNIC2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, } for _, test := range tests { @@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) { if err == tcpip.ErrNoLinkAddress { // Wait for link resolution to complete. <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } else if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) } - if err != nil { t.Fatalf("ep.Write(_, _): %s", err) } @@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress v, _, err := ep.Read(&addr) if err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 6ddcda70c..fe7c1bb3d 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -32,32 +32,36 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +const ( + linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +) - host1IPv4Addr = tcpip.ProtocolAddress{ +var ( + ipv4Addr1 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - host2IPv4Addr = tcpip.ProtocolAddress{ + ipv4Addr2 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr2 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), @@ -89,7 +93,7 @@ func TestPing(t *testing.T) { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, - remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) @@ -104,7 +108,7 @@ func TestPing(t *testing.T) { name: "IPv6 Ping", transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, - remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) @@ -127,7 +131,7 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -143,36 +147,36 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, tcpip.Route{ - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, tcpip.Route{ - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, }) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index e8caf09ba..421da1add 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }, }) - wq := waiter.Queue{} + var wq waiter.Queue rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index f1028823b..cdf0459e3 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - wq := waiter.Queue{} + var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) @@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) { loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") ) - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) - tests := []struct { name string broadcastAddr tcpip.Address @@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, }) + type endpointAndWaiter struct { + ep tcpip.Endpoint + ch chan struct{} + } + var eps []endpointAndWaiter // We create endpoints that bind to both the wildcard address and the // broadcast address to make sure both of these types of "broadcast // interested" endpoints receive broadcast packets. - wq := waiter.Queue{} - var eps []tcpip.Endpoint for _, bindWildcard := range []bool{false, true} { // Create multiple endpoints for each type of "broadcast interested" // endpoint so we can test that all endpoints receive the broadcast // packet. for i := 0; i < 2; i++ { + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) @@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } } - eps = append(eps, ep) + eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) } } @@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - if n, _, err := wep.Write(data, writeOpts); err != nil { + data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) + if n, _, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) } for j, rep := range eps { - if gotPayload, _, err := rep.Read(nil); err != nil { + // Wait for the endpoint to become readable. + <-rep.ch + + if gotPayload, _, err := rep.ep.Read(nil); err != nil { t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go new file mode 100644 index 000000000..02fc47015 --- /dev/null +++ b/pkg/tcpip/tests/integration/route_test.go @@ -0,0 +1,388 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLocalPing tests pinging a remote that is local the stack. +// +// This tests that a local route is created and packets do not leave the stack. +func TestLocalPing(t *testing.T) { + const ( + nicID = 1 + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } + channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { + channelEP := e.(*channel.Endpoint) + if n := channelEP.Drain(); n != 0 { + t.Fatalf("got channelEP.Drain() = %d, want = 0", n) + } + } + + ipv4ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + ipv6ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + linkEndpoint func() stack.LinkEndpoint + localAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + expectedConnectErr *tcpip.Error + checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) + }{ + { + name: "IPv4 loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: ipv4Loopback, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: header.IPv6Loopback, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv4Addr.Address, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv6Addr.Address, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv4 loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } + + if test.expectedConnectErr != nil { + return + } + + payload := tcpip.SlicePayload(test.icmpBuf(t)) + var wOpts tcpip.WriteOptions + if n, _, err := ep.Write(payload, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } + + // Wait for the endpoint to become readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } +} + +// TestLocalUDP tests sending UDP packets between two endpoints that are local +// to the stack. +// +// This tests that that packets never leave the stack and the addresses +// used when sending a packet. +func TestLocalUDP(t *testing.T) { + const ( + nicID = 1 + ) + + tests := []struct { + name string + canBePrimaryAddr tcpip.ProtocolAddress + firstPrimaryAddr tcpip.ProtocolAddress + }{ + { + name: "IPv4", + canBePrimaryAddr: ipv4Addr1, + firstPrimaryAddr: ipv4Addr2, + }, + { + name: "IPv6", + canBePrimaryAddr: ipv6Addr1, + firstPrimaryAddr: ipv6Addr2, + }, + } + + subTests := []struct { + name string + addAddress bool + expectedWriteErr *tcpip.Error + }{ + { + name: "Unassigned local address", + addAddress: false, + expectedWriteErr: tcpip.ErrNoRoute, + }, + { + name: "Assigned local address", + addAddress: true, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + s := stack.New(stackOpts) + ep := channel.New(1, header.IPv6MinimumMTU, "") + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if subTest.addAddress { + if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + } + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer server.Close() + + bindAddr := tcpip.FullAddress{Port: 80} + if err := server.Bind(bindAddr); err != nil { + t.Fatalf("server.Bind(%#v): %s", bindAddr, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventIn) + client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer client.Close() + + serverAddr := tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + Port: 80, + } + + clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &serverAddr, + } + if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + } else if subTest.expectedWriteErr != nil { + // Nothing else to test if we expected not to be able to send the + // UDP packet. + return + } else if n != int64(len(clientPayload)) { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + } + } + + // Wait for the server endpoint to become readable. + <-serverCH + + var clientAddr tcpip.FullAddress + if v, _, err := server.Read(&clientAddr); err != nil { + t.Fatalf("server.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + } + if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { + t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + } + if t.Failed() { + t.FailNow() + } + } + + serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &clientAddr, + } + if n, _, err := server.Write(serverPayload, wOpts); err != nil { + t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) + } else if n != int64(len(serverPayload)) { + t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + } + } + + // Wait for the client endpoint to become readable. + <-clientCH + + var gotServerAddr tcpip.FullAddress + if v, _, err := client.Read(&gotServerAddr); err != nil { + t.Fatalf("client.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + } + if gotServerAddr.Addr != serverAddr.Addr { + t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + } + if t.Failed() { + t.FailNow() + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a17234946..763cd8f84 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, }, } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 87d510f96..3820e5dc7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 79f688129..7b6a87ba9 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { return } + remoteAddr := pkt.Network().SourceAddress() + if e.bound { // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != route.NICID() { + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() return } @@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != route.RemoteAddress { + if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() return } @@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // Push new packet into receive list and increment the buffer size. packet := &rawPacket{ senderAddr: tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress, + NIC: pkt.NICID, + Addr: remoteAddr, }, } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 6b3238d6b..47982ca41 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { - netProto = s.route.NetProto + netProto = s.netProto } + + route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(s.remoteLinkAddr) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6Only n.ID = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.boundNICID = s.nicID + n.route = route + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} n.rcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + if err != nil { + return nil, err + } // Lock the endpoint before registering to ensure that no out of // band changes are possible due to incoming packets etc till @@ -467,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -477,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } switch { @@ -492,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. - return + return nil } ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } else { // If cookies are in use but the endpoint accept queue // is full then drop the syn. @@ -506,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Send SYN without window scaling because we currently // don't encode this information in the cookie. // @@ -523,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, s.route), + MSS: calculateAdvertisedMSS(e.userMSS, route), } - e.sendSynTCP(&s.route, tcpFields{ + fields := tcpFields{ id: s.id, ttl: e.ttl, tos: e.sendTOS, @@ -533,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { seq: cookie, ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, - }, synOpts) + } + if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + return err + } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil } case (s.flags & header.TCPFlagAck) != 0: @@ -547,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } if !ctx.synRcvdCount.synCookiesInUse() { @@ -566,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } iss := s.ackNumber - 1 @@ -587,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -608,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + if err != nil { + return err + } n.mu.Lock() @@ -622,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return nil } // Register new endpoint so that packets are routed to it. @@ -632,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return err } n.isRegistered = true @@ -670,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) + return nil + + default: + return nil } } // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.v6only ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) @@ -714,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { - return nil + return } if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { s := e.segmentQueue.dequeue() - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } close(e.drainDone) @@ -738,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { break } - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 0aaef495d..2facbebec 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { MSS: amss, } if ttl == 0 { - ttl = s.route.DefaultTTL() + ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } // Wait for notification. - index, _ = s.Fetch(true) + h.ep.mu.Unlock() + index, _ = s.Fetch(true /* block */) + h.ep.mu.Lock() } } @@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error { }, synOpts) for h.state != handshakeCompleted { + // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held + // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true) + index, _ := s.Fetch(true /* block */) h.ep.mu.Lock() switch index { @@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // TCP header, then the kernel calculate a checksum of the // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) - } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + } else if r.RequiresTXTransportChecksum() { xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } if ep == nil { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) s.decRef() return } @@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ drained := e.drainDone != nil if drained { close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } // Set up the functions that will be called when the main protocol loop @@ -1535,7 +1541,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() // We need to double check here because the notification may be @@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} } for _, netProto := range netProtos { - if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil { tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { @@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() switch v { case newSegment: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 98aecab9e..21162f01a 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -172,10 +172,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newSegment(r, id, pkt) - if !s.parse() { + + s := newIncomingSegment(id, pkt) + if !s.parse(pkt.RXTransportChecksumValidated) { ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 560b4904c..a6f25896b 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) { testV6Connect(t, c) } +func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: defaultMTU, + }) + defer c.Cleanup() + + // Create a v6 endpoint but don't set the v6-only TCP option. + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c826942e9..258f9f1bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -721,9 +721,9 @@ func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned // by another user goroutine if not then we spin, otherwise - // we just goto sleep on the Lock() and wait. + // we just go to sleep on the Lock() and wait. if !e.mu.TryLock() { - // If socket is owned by the user then just goto sleep + // If socket is owned by the user then just go to sleep // as the lock could be held for a reasonably long time. if atomic.LoadUint32(&e.ownedByUser) == 1 { e.mu.Lock() @@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) + s := newOutgoingSegment(e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -2316,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // done yet) or the reservation was freed between the check above and // the FindTransportEndpoint below. But rather than retry the same port // we just skip it and move on. - transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID()) if transEP == nil { // ReservePort failed but there is no registered endpoint with // demuxer. Which indicates there is at least some endpoint that has @@ -2385,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.ID - s.route = r.Clone() e.sndWaker.Assert() } } @@ -2451,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.ID, nil) + s := newOutgoingSegment(e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ // Mark endpoint as closed. @@ -2633,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + + // Expand netProtos to include v4 and v6 under dual-stack if the caller is + // binding to a wildcard (empty) address, and this is an IPv6 endpoint with + // v6only set to false. + if netProto == header.IPv6ProtocolNumber { + stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) + alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + if alsoBindToV4 { + netProtos = append(netProtos, header.IPv4ProtocolNumber) } } @@ -2721,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { } } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -3080,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() { } func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() - } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + } else if e.route.HasSoftwareGSOCapability() { e.gso = &stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b25431467..2bcc5e1c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() { switch { case epState == StateInitial || epState == StateBound: case epState.connected() || epState.handshake(): - if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { - if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + if !e.route.HasSaveRestoreCapability() { + if !e.route.HasDisconncetOkCapability() { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 070b634b4..0664789da 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -30,6 +30,8 @@ import ( // The canonical way of using it is to pass the Forwarder.HandlePacket function // to stack.SetTransportProtocolHandler. type Forwarder struct { + stack *stack.Stack + maxInFlight int handler func(*ForwarderRequest) @@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ + stack: s, maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), @@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newSegment(r, id, pkt) +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newIncomingSegment(id, pkt) defer s.decRef() // We only care about well-formed SYN packets. - if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn { return false } @@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) { delete(r.forwarder.inFlight, r.segment.id) r.forwarder.mu.Unlock() - // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 5bce73605..2329aca4b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(r, ep, id, pkt) +func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(ep, id, pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newSegment(r, id, pkt) +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + s := newIncomingSegment(id, pkt) defer s.decRef() - if !s.parse() || !s.csumValid { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } if !s.flagIsSet(header.TCPFlagRst) { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(p.stack, s, stack.DefaultTOS, 0) } return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment, tos, ttl uint8) { +// +// If the passed TTL is 0, then the route's default TTL will be used. +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { + route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, tcpFields{ + + if ttl == 0 { + ttl = route.DefaultTTL() + } + + return sendTCP(&route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1f9c5cf50..2091989cc 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -45,9 +46,18 @@ type segment struct { ep *endpoint qFlags queueFlags id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + + // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of + // individual members for link/network packet info. + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID + remoteLinkAddr tcpip.LinkAddress + + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -76,11 +86,16 @@ type segment struct { acked bool } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, + remoteLinkAddr: pkt.SourceLinkAddress(), } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB return s } -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, - route: r.Clone(), } s.rcvdTime = time.Now() if len(v) != 0 { @@ -110,7 +124,9 @@ func (s *segment) clone() *segment { ackNumber: s.ackNumber, flags: s.flags, window: s.window, - route: s.route.Clone(), + netProto: s.netProto, + nicID: s.nicID, + remoteLinkAddr: s.remoteLinkAddr, viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, @@ -160,7 +176,6 @@ func (s *segment) decRef() { panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) } } - s.route.Release() } } @@ -198,10 +213,10 @@ func (s *segment) segMemSize() int { // // Returns boolean indicating if the parsing was successful. // -// If checksum verification is not offloaded then parse also verifies the +// If checksum verification may not be skipped, parse also verifies the // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. -func (s *segment) parse() bool { +func (s *segment) parse(skipChecksumValidation bool) bool { // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -220,16 +235,14 @@ func (s *segment) parse() bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - // Query the link capabilities to decide if checksum validation is - // required. verifyChecksum := true - if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + if skipChecksumValidation { s.csumValid = true verifyChecksum = false } if verifyChecksum { s.csum = s.hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6fa8d63cd..ab5fa4fb7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 { + return + } + // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 79646fefe..f791f8f13 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -112,6 +112,18 @@ type Headers struct { TCPOpts []byte } +// Options contains options for creating a new test context. +type Options struct { + // EnableV4 indicates whether IPv4 should be enabled. + EnableV4 bool + + // EnableV6 indicates whether IPv4 should be enabled. + EnableV6 bool + + // MTU indicates the maximum transmission unit on the link layer. + MTU uint32 +} + // Context provides an initialized Network stack and a link layer endpoint // for use in TCP tests. type Context struct { @@ -154,10 +166,30 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + return NewWithOpts(t, Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, }) +} + +// NewWithOpts allocates and initializes a test context containing a new +// stack and a link-layer endpoint with specific options. +func NewWithOpts(t *testing.T, opts Options) *Context { + if opts.MTU == 0 { + panic("MTU must be greater than 0") + } + + stackOpts := stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + if opts.EnableV4 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) + } + if opts.EnableV6 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) + } + s := stack.New(stackOpts) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB @@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - ep := channel.New(1000, mtu, "") + ep := channel.New(1000, opts.MTU, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) + wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) } opts2 := stack.NICOptions{Name: "nic2"} if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } + var routeTable []tcpip.Route - s.SetRouteTable([]tcpip.Route{ - { + if opts.EnableV4 { + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, - }, - { + }) + } + + if opts.EnableV6 { + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, NIC: 1, - }, - }) + }) + } + + s.SetRouteTable(routeTable) return &Context{ t: t, diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cdb5127ab..9bcb918bb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } + if to.Port == 0 { + // Port 0 is an invalid port to send to. + return 0, nil, tcpip.ErrInvalidEndpointState + } + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -1012,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + if r.RequiresTXTransportChecksum() && (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { @@ -1382,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // On IPv4, UDP checksum is optional, and a zero value means the transmitter // omitted the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). -func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) +func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { + if !pkt.RXTransportChecksumValidated && + (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { + netHdr := pkt.Network() + xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) } @@ -1396,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { @@ -1406,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - if !verifyChecksum(r, hdr, pkt) { + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() @@ -1437,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, @@ -1447,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvBufSize += pkt.Data.Size() // Save any useful information from the network header to the packet. - switch r.NetProto { + switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: @@ -1457,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast // address. packetInfo.LocalAddr should hold a unicast address that can be // used to respond to the incoming packet. - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.LocalAddress - packet.packetInfo.NIC = r.NICID() + localAddr := pkt.Network().DestinationAddress() + packet.packetInfo.LocalAddr = localAddr + packet.packetInfo.DestinationAddr = localAddr + packet.packetInfo.NIC = pkt.NICID packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 3ae6cc221..14e4648cd 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, - route: r, id: id, pkt: pkt, }) @@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p // it via CreateEndpoint. type ForwarderRequest struct { stack *stack.Stack - route *stack.Route id stack.TransportEndpointID pkt *stack.PacketBuffer } @@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + netHdr := r.pkt.Network() + route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(r.pkt.SourceLinkAddress()) + + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() + route.Release() return nil, err } ep.ID = r.id - ep.route = r.route.Clone() + ep.route = route ep.dstPort = r.id.RemotePort - ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} - ep.RegisterNICID = r.route.NICID() + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} + ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.pkt) + ep.HandlePacket(r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index da5b1deb2..91420edd3 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets that are targeted at this // protocol but don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } - if !verifyChecksum(r, hdr, pkt) { - r.Stack().Stats().UDP.ChecksumErrors.Increment() + if !verifyChecksum(hdr, pkt) { + p.stack.Stats().UDP.ChecksumErrors.Increment() return stack.UnknownDestinationPacketMalformed } diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index 5c4b9e8e9..a38ffc19d 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -53,40 +53,40 @@ func randomFilename() (string, error) { func TestConnectFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") + t.Fatalf("Connect was successful, expected err") } } func TestBindFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") + t.Fatalf("Second bind succeeded, expected non-nil err") } } func TestMultipleAccept(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() @@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) { defer wg.Done() s, err := Connect(name, false) if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) + t.Errorf("Connect failed, got err %v expected nil", err) + return } s.Close() }() @@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) { for i := 0; i < backlog; i++ { s, err := ss.Accept() if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) + t.Errorf("Accept failed, got err %v expected nil", err) continue } s.Close() @@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) { func TestServerClose(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } // Make sure the first close succeeds. if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } // Bind a server. ss, err := BindAndListen(name, packet) if err != nil { - t.Fatalf("error binding, got %v expected nil", err) + t.Fatalf("Error binding, got %v expected nil", err) } defer ss.Close() @@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { // Connect the client. client, err := Connect(name, packet) if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) + t.Fatalf("Error connecting, got %v expected nil", err) } // Grab the server handle. @@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { case server := <-acceptSocket: return server, client case err := <-acceptErr: - t.Fatalf("accept error: %v", err) + t.Fatalf("Accept error: %v", err) } panic("unreachable") } @@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. b := [][]byte{{'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) { // Write on the server. w := server.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -233,13 +234,13 @@ func TestPacket(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Write on the client again. w = client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. @@ -249,19 +250,19 @@ func TestPacket(t *testing.T) { b := [][]byte{{'b', 'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Do it again. r = server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -271,12 +272,12 @@ func TestClose(t *testing.T) { // Make sure the first close succeeds. if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } @@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) { // We're good. That's what we wanted. blockCount++ } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err) } } } if blockCount == 1000 { // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") + t.Fatalf("Socket always blocked!") } else if blockCount == 0 { // Should have started blocking eventually. - t.Fatalf("socket never blocked!") + t.Fatalf("Socket never blocked!") } } @@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) { // Expected to block immediately. _, err := r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } // Put some data in the pipe. w := server.Writer(false) if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it not to block. if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it to return a block error again. r = client.Reader(false) _, err = r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } } @@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c'}, {'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) } } @@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c', 'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) } } @@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) { w := server.Writer(true) w.PackFDs(0, 1, 2) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client, without enabling FDs. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Make sure the FDs are not received. fds, err := r.ExtractFDs() if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) } } @@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) { w := s.Writer(true) w.PackFDs(fds...) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err) } } @@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Count the number of FDs. preEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } // Read on the client. @@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { r.EnableFDs(enableSize) } if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Count the new number of FDs. postEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) } // Make sure the FDs are there. fds, err := r.ExtractFDs() if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) } // Make sure they are different from the originals. for i := 0; i < len(fds); i++ { if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) + t.Errorf("Got original fd for index %d, expected different", i) } } @@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Make sure the count is back to normal. finalEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) } } @@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) { } if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) + t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) } } @@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) { want := "bad file descriptor" if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) + t.Errorf("s.GetPeerCred() = %v, want = %s", err, want) } } func TestAcceptClosed(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Fatalf("Close failed, got err %v expected nil", err) } if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } } func TestCloseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Errorf("Close failed, got err %v expected nil", err) } - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) { func TestReleaseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) fd, err := ss.Release() if err != nil { - t.Fatalf("Release failed, got err %v expected nil", err) + t.Errorf("Release failed, got err %v expected nil", err) } syscall.Close(fd) - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) { cm.PackFDs(want...) got, err := cm.ExtractFDs() if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) + t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) } } } @@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) { for i := 0; i < b.N; i++ { n, err := server.Read(buf) if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + return } n, err = server.Write(buf) if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + return } } }() diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index b97dc3c47..8c73dc5dc 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -75,6 +75,7 @@ go_library( "//pkg/sentry/platform", "//pkg/sentry/sighandling", "//pkg/sentry/socket/hostinet", + "//pkg/sentry/socket/netfilter", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netlink/route", "//pkg/sentry/socket/netlink/uevent", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 4e0f0d57a..fdf13c8e1 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -196,7 +196,7 @@ type containerManager struct { // StartRoot will start the root container process. func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error { - log.Debugf("containerManager.StartRoot %q", *cid) + log.Debugf("containerManager.StartRoot, cid: %s", *cid) // Tell the root container to start and wait for the result. cm.startChan <- struct{}{} if err := <-cm.startResultChan; err != nil { @@ -207,13 +207,13 @@ func (cm *containerManager) StartRoot(cid *string, _ *struct{}) error { // Processes retrieves information about processes running in the sandbox. func (cm *containerManager) Processes(cid *string, out *[]*control.Process) error { - log.Debugf("containerManager.Processes: %q", *cid) + log.Debugf("containerManager.Processes, cid: %s", *cid) return control.Processes(cm.l.k, *cid, out) } // Create creates a container within a sandbox. func (cm *containerManager) Create(cid *string, _ *struct{}) error { - log.Debugf("containerManager.Create: %q", *cid) + log.Debugf("containerManager.Create, cid: %s", *cid) return cm.l.createContainer(*cid) } @@ -237,12 +237,11 @@ type StartArgs struct { // Start runs a created container within a sandbox. func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { - log.Debugf("containerManager.Start: %+v", args) - // Validate arguments. if args == nil { return errors.New("start missing arguments") } + log.Debugf("containerManager.Start, cid: %s, args: %+v", args.CID, args) if args.Spec == nil { return errors.New("start arguments missing spec") } @@ -269,27 +268,27 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { } }() if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil { - log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err) + log.Debugf("containerManager.Start failed, cid: %s, args: %+v, err: %v", args.CID, args, err) return err } - log.Debugf("Container %q started", args.CID) + log.Debugf("Container started, cid: %s", args.CID) return nil } // Destroy stops a container if it is still running and cleans up its // filesystem. func (cm *containerManager) Destroy(cid *string, _ *struct{}) error { - log.Debugf("containerManager.destroy %q", *cid) + log.Debugf("containerManager.destroy, cid: %s", *cid) return cm.l.destroyContainer(*cid) } // ExecuteAsync starts running a command on a created or running sandbox. It // returns the PID of the new process. func (cm *containerManager) ExecuteAsync(args *control.ExecArgs, pid *int32) error { - log.Debugf("containerManager.ExecuteAsync: %+v", args) + log.Debugf("containerManager.ExecuteAsync, cid: %s, args: %+v", args.ContainerID, args) tgid, err := cm.l.executeAsync(args) if err != nil { - log.Debugf("containerManager.ExecuteAsync failed: %+v: %v", args, err) + log.Debugf("containerManager.ExecuteAsync failed, cid: %s, args: %+v, err: %v", args.ContainerID, args, err) return err } *pid = int32(tgid) @@ -453,9 +452,9 @@ func (cm *containerManager) Resume(_, _ *struct{}) error { // Wait waits for the init process in the given container. func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error { - log.Debugf("containerManager.Wait") + log.Debugf("containerManager.Wait, cid: %s", *cid) err := cm.l.waitContainer(*cid, waitStatus) - log.Debugf("containerManager.Wait returned, waitStatus: %v: %v", waitStatus, err) + log.Debugf("containerManager.Wait returned, cid: %s, waitStatus: %#x, err: %v", *cid, *waitStatus, err) return err } @@ -470,8 +469,10 @@ type WaitPIDArgs struct { // WaitPID waits for the process with PID 'pid' in the sandbox. func (cm *containerManager) WaitPID(args *WaitPIDArgs, waitStatus *uint32) error { - log.Debugf("containerManager.Wait") - return cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus) + log.Debugf("containerManager.Wait, cid: %s, pid: %d", args.CID, args.PID) + err := cm.l.waitPID(kernel.ThreadID(args.PID), args.CID, waitStatus) + log.Debugf("containerManager.Wait, cid: %s, pid: %d, waitStatus: %#x, err: %v", args.CID, args.PID, *waitStatus, err) + return err } // SignalDeliveryMode enumerates different signal delivery modes. @@ -528,6 +529,6 @@ type SignalArgs struct { // indicated process, to all processes in the container, or to the foreground // process group. func (cm *containerManager) Signal(args *SignalArgs, _ *struct{}) error { - log.Debugf("containerManager.Signal %+v", args) + log.Debugf("containerManager.Signal: cid: %s, PID: %d, signal: %d, mode: %v", args.CID, args.PID, args.Signo, args.Mode) return cm.l.signal(args.CID, args.PID, args.Signo, args.Mode) } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 10f0f452b..ebdd518d0 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -50,6 +50,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/pkg/sentry/sighandling" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/syscalls/linux/vfs2" "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -887,7 +888,7 @@ func (l *Loader) destroyContainer(cid string) error { } } - log.Debugf("Container destroyed %q", cid) + log.Debugf("Container destroyed, cid: %s", cid) return nil } @@ -1084,6 +1085,7 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in // privileges. RawFactory: raw.EndpointFactory{}, UniqueID: uniqueID, + IPTables: netfilter.DefaultLinuxTables(), })} // Enable SACK Recovery. diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index 8fe0c427a..c0bc8f064 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -75,7 +75,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } @@ -149,6 +149,9 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa } ws, err := cont.Wait() + if err != nil { + Fatalf("Error waiting for container: %v", err) + } *waitStatus = ws return subcommands.ExitSuccess diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 132198222..609e8231c 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -91,7 +91,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } var err error - c, err = container.Load(conf.RootDir, f.Arg(0)) + c, err = container.LoadAndCheck(conf.RootDir, f.Arg(0)) if err != nil { return Errorf("loading container %q: %v", f.Arg(0), err) } @@ -106,7 +106,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return Errorf("listing containers: %v", err) } for _, id := range ids { - candidate, err := container.Load(conf.RootDir, id) + candidate, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { return Errorf("loading container %q: %v", id, err) } diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go index 4e49deff8..a25637265 100644 --- a/runsc/cmd/delete.go +++ b/runsc/cmd/delete.go @@ -68,7 +68,7 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} func (d *Delete) execute(ids []string, conf *config.Config) error { for _, id := range ids { - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { if os.IsNotExist(err) && d.force { log.Warningf("couldn't find container %q: %v", id, err) diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 25fe2cf1c..3836b7b4e 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -74,7 +74,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } @@ -85,7 +85,12 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa ev, err := c.Event() if err != nil { log.Warningf("Error getting events for container: %v", err) + if evs.stats { + return subcommands.ExitFailure + } } + log.Debugf("Events: %+v", ev) + // err must be preserved because it is used below when breaking // out of the loop. b, err := json.Marshal(ev) @@ -101,11 +106,9 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa if err != nil { return subcommands.ExitFailure } - break + return subcommands.ExitSuccess } time.Sleep(time.Duration(evs.intervalSec) * time.Second) } - - return subcommands.ExitSuccess } diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index 775ed4b43..86c02a22a 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -112,7 +112,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } waitStatus := args[1].(*syscall.WaitStatus) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index 04eee99b2..fe69e2a08 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -69,7 +69,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("it is invalid to specify both --all and --pid") } - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go index f92d6fef9..6907eb16a 100644 --- a/runsc/cmd/list.go +++ b/runsc/cmd/list.go @@ -79,7 +79,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Collect the containers. var containers []*container.Container for _, id := range ids { - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container %q: %v", id, err) } diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go index 0eb1402ed..fe7d4e257 100644 --- a/runsc/cmd/pause.go +++ b/runsc/cmd/pause.go @@ -55,7 +55,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go index bc58c928f..18d7a1436 100644 --- a/runsc/cmd/ps.go +++ b/runsc/cmd/ps.go @@ -60,7 +60,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading sandbox: %v", err) } diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go index f24823f99..a00928204 100644 --- a/runsc/cmd/resume.go +++ b/runsc/cmd/resume.go @@ -56,7 +56,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} id := f.Arg(0) conf := args[0].(*config.Config) - cont, err := container.Load(conf.RootDir, id) + cont, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 139edbd49..f6499cc44 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -55,7 +55,7 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go index 2bd2ab9f8..d8a70dd7f 100644 --- a/runsc/cmd/state.go +++ b/runsc/cmd/state.go @@ -57,7 +57,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index 28d0642ed..c1d6aeae2 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -72,7 +72,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) id := f.Arg(0) conf := args[0].(*config.Config) - c, err := container.Load(conf.RootDir, id) + c, err := container.LoadAndCheck(conf.RootDir, id) if err != nil { Fatalf("loading container: %v", err) } diff --git a/runsc/config/flags.go b/runsc/config/flags.go index d3203b565..13d8f1b25 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -29,7 +29,7 @@ import ( var registration sync.Once -// This is the set of flags used to populate Config. +// RegisterFlags registers flags used to populate Config. func RegisterFlags() { registration.Do(func() { // Although these flags are not part of the OCI spec, they are used by diff --git a/runsc/container/container.go b/runsc/container/container.go index 52e1755ce..4aa139c88 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -159,9 +159,9 @@ func loadSandbox(rootDir, id string) ([]*Container, error) { // container to which id unambiguously refers to. Returns ErrNotExist if // container doesn't exist. func Load(rootDir, partialID string) (*Container, error) { - log.Debugf("Load container %q %q", rootDir, partialID) + log.Debugf("Load container, rootDir: %q, partial cid: %s", rootDir, partialID) if err := validateID(partialID); err != nil { - return nil, fmt.Errorf("validating id: %v", err) + return nil, fmt.Errorf("invalid container id: %v", err) } id, err := findContainerID(rootDir, partialID) @@ -184,22 +184,31 @@ func Load(rootDir, partialID string) (*Container, error) { } return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err) } + return c, nil +} + +// LoadAndCheck is similar to Load(), but also checks if the container is still +// running to get an error earlier to the caller. +func LoadAndCheck(rootDir, partialID string) (*Container, error) { + c, err := Load(rootDir, partialID) + if err != nil { + // Preserve error so that callers can distinguish 'not found' errors. + return nil, err + } - // If the status is "Running" or "Created", check that the sandbox - // process still exists, and set it to Stopped if it does not. + // If the status is "Running" or "Created", check that the sandbox/container + // is still running, setting it to Stopped if not. // // This is inherently racy. - if c.Status == Running || c.Status == Created { - // Check if the sandbox process is still running. + switch c.Status { + case Created: if !c.isSandboxRunning() { // Sandbox no longer exists, so this container definitely does not exist. c.changeStatus(Stopped) - } else if c.Status == Running { - // Container state should reflect the actual state of the application, so - // we don't consider gofer process here. - if err := c.SignalContainer(syscall.Signal(0), false); err != nil { - c.changeStatus(Stopped) - } + } + case Running: + if err := c.SignalContainer(syscall.Signal(0), false); err != nil { + c.changeStatus(Stopped) } } @@ -271,7 +280,7 @@ type Args struct { // indicates that an existing Sandbox should be used. The caller must call // Destroy() on the container. func New(conf *config.Config, args Args) (*Container, error) { - log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir) + log.Debugf("Create container, cid: %s, rootDir: %q", args.ID, conf.RootDir) if err := validateID(args.ID); err != nil { return nil, err } @@ -310,7 +319,7 @@ func New(conf *config.Config, args Args) (*Container, error) { // indicate the ID of the sandbox, which is the same as the ID of the // init container in the sandbox. if isRoot(args.Spec) { - log.Debugf("Creating new sandbox for container %q", args.ID) + log.Debugf("Creating new sandbox for container, cid: %s", args.ID) if args.Spec.Linux == nil { args.Spec.Linux = &specs.Linux{} @@ -380,10 +389,10 @@ func New(conf *config.Config, args Args) (*Container, error) { if !ok { return nil, fmt.Errorf("no sandbox ID found when creating container") } - log.Debugf("Creating new container %q in sandbox %q", c.ID, sbid) + log.Debugf("Creating new container, cid: %s, sandbox: %s", c.ID, sbid) // Find the sandbox associated with this ID. - sb, err := Load(conf.RootDir, sbid) + sb, err := LoadAndCheck(conf.RootDir, sbid) if err != nil { return nil, err } @@ -413,7 +422,7 @@ func New(conf *config.Config, args Args) (*Container, error) { // Start starts running the containerized process inside the sandbox. func (c *Container) Start(conf *config.Config) error { - log.Debugf("Start container %q", c.ID) + log.Debugf("Start container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err @@ -476,7 +485,7 @@ func (c *Container) Start(conf *config.Config) error { unlock.Clean() // Adjust the oom_score_adj for sandbox. This must be done after saveLocked(). - if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil { + if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Spec, c.Saver.RootDir, false); err != nil { return err } @@ -488,7 +497,7 @@ func (c *Container) Start(conf *config.Config) error { // Restore takes a container and replaces its kernel and file system // to restore a container from its state file. func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile string) error { - log.Debugf("Restore container %q", c.ID) + log.Debugf("Restore container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err } @@ -515,7 +524,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile s // Run is a helper that calls Create + Start + Wait. func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { - log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir) + log.Debugf("Run container, cid: %s, rootDir: %q", args.ID, conf.RootDir) c, err := New(conf, args) if err != nil { return 0, fmt.Errorf("creating container: %v", err) @@ -547,7 +556,7 @@ func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { // Execute runs the specified command in the container. It returns the PID of // the newly created process. func (c *Container) Execute(args *control.ExecArgs) (int32, error) { - log.Debugf("Execute in container %q, args: %+v", c.ID, args) + log.Debugf("Execute in container, cid: %s, args: %+v", c.ID, args) if err := c.requireStatus("execute in", Created, Running); err != nil { return 0, err } @@ -557,7 +566,7 @@ func (c *Container) Execute(args *control.ExecArgs) (int32, error) { // Event returns events for the container. func (c *Container) Event() (*boot.Event, error) { - log.Debugf("Getting events for container %q", c.ID) + log.Debugf("Getting events for container, cid: %s", c.ID) if err := c.requireStatus("get events for", Created, Running, Paused); err != nil { return nil, err } @@ -577,14 +586,19 @@ func (c *Container) SandboxPid() int { // Call to wait on a stopped container is needed to retrieve the exit status // and wait returns immediately. func (c *Container) Wait() (syscall.WaitStatus, error) { - log.Debugf("Wait on container %q", c.ID) - return c.Sandbox.Wait(c.ID) + log.Debugf("Wait on container, cid: %s", c.ID) + ws, err := c.Sandbox.Wait(c.ID) + if err == nil { + // Wait succeeded, container is not running anymore. + c.changeStatus(Stopped) + } + return ws, err } // WaitRootPID waits for process 'pid' in the sandbox's PID namespace and // returns its WaitStatus. func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { - log.Debugf("Wait on PID %d in sandbox %q", pid, c.Sandbox.ID) + log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID) if !c.isSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } @@ -594,7 +608,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { // WaitPID waits for process 'pid' in the container's PID namespace and returns // its WaitStatus. func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { - log.Debugf("Wait on PID %d in container %q", pid, c.ID) + log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID) if !c.isSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") } @@ -606,7 +620,7 @@ func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { // SignalContainer returns an error if the container is already stopped. // TODO(b/113680494): Distinguish different error types. func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { - log.Debugf("Signal container %q: %v", c.ID, sig) + log.Debugf("Signal container, cid: %s, signal: %v (%d)", c.ID, sig, sig) // Signaling container in Stopped state is allowed. When all=false, // an error will be returned anyway; when all=true, this allows // sending signal to other processes inside the container even @@ -623,7 +637,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { // SignalProcess sends sig to a specific process in the container. func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { - log.Debugf("Signal process %d in container %q: %v", pid, c.ID, sig) + log.Debugf("Signal process %d in container, cid: %s, signal: %v (%d)", pid, c.ID, sig, sig) if err := c.requireStatus("signal a process inside", Running); err != nil { return err } @@ -637,15 +651,15 @@ func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { // container process inside the sandbox. It returns a function that will stop // forwarding signals. func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() { - log.Debugf("Forwarding all signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess) + log.Debugf("Forwarding all signals to container, cid: %s, PIDPID: %d, fgProcess: %t", c.ID, pid, fgProcess) stop := sighandling.StartSignalForwarding(func(sig linux.Signal) { - log.Debugf("Forwarding signal %d to container %q PID %d fgProcess=%t", sig, c.ID, pid, fgProcess) + log.Debugf("Forwarding signal %d to container, cid: %s, PID: %d, fgProcess: %t", sig, c.ID, pid, fgProcess) if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil { log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err) } }) return func() { - log.Debugf("Done forwarding signals to container %q PID %d fgProcess=%t", c.ID, pid, fgProcess) + log.Debugf("Done forwarding signals to container, cid: %s, PID: %d, fgProcess: %t", c.ID, pid, fgProcess) stop() } } @@ -653,7 +667,7 @@ func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() { // Checkpoint sends the checkpoint call to the container. // The statefile will be written to f, the file at the specified image-path. func (c *Container) Checkpoint(f *os.File) error { - log.Debugf("Checkpoint container %q", c.ID) + log.Debugf("Checkpoint container, cid: %s", c.ID) if err := c.requireStatus("checkpoint", Created, Running, Paused); err != nil { return err } @@ -663,7 +677,7 @@ func (c *Container) Checkpoint(f *os.File) error { // Pause suspends the container and its kernel. // The call only succeeds if the container's status is created or running. func (c *Container) Pause() error { - log.Debugf("Pausing container %q", c.ID) + log.Debugf("Pausing container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err } @@ -674,7 +688,7 @@ func (c *Container) Pause() error { } if err := c.Sandbox.Pause(c.ID); err != nil { - return fmt.Errorf("pausing container: %v", err) + return fmt.Errorf("pausing container %q: %v", c.ID, err) } c.changeStatus(Paused) return c.saveLocked() @@ -683,7 +697,7 @@ func (c *Container) Pause() error { // Resume unpauses the container and its kernel. // The call only succeeds if the container's status is paused. func (c *Container) Resume() error { - log.Debugf("Resuming container %q", c.ID) + log.Debugf("Resuming container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err } @@ -722,7 +736,7 @@ func (c *Container) Processes() ([]*control.Process, error) { // Destroy stops all processes and frees all resources associated with the // container. func (c *Container) Destroy() error { - log.Debugf("Destroy container %q", c.ID) + log.Debugf("Destroy container, cid: %s", c.ID) if err := c.Saver.lock(); err != nil { return err @@ -759,14 +773,12 @@ func (c *Container) Destroy() error { c.changeStatus(Stopped) // Adjust oom_score_adj for the sandbox. This must be done after the container - // is stopped and the directory at c.Root is removed. Adjustment can be - // skipped if the root container is exiting, because it brings down the entire - // sandbox. + // is stopped and the directory at c.Root is removed. // // Use 'sb' to tell whether it has been executed before because Destroy must // be idempotent. - if sb != nil && !isRoot(c.Spec) { - if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil { + if sb != nil { + if err := adjustSandboxOOMScoreAdj(sb, c.Spec, c.Saver.RootDir, true); err != nil { errs = append(errs, err.Error()) } } @@ -795,7 +807,7 @@ func (c *Container) Destroy() error { // // Precondition: container must be locked with container.lock(). func (c *Container) saveLocked() error { - log.Debugf("Save container %q", c.ID) + log.Debugf("Save container, cid: %s", c.ID) if err := c.Saver.saveLocked(c); err != nil { return fmt.Errorf("saving container metadata: %v", err) } @@ -809,7 +821,7 @@ func (c *Container) stop() error { var cgroup *cgroup.Cgroup if c.Sandbox != nil { - log.Debugf("Destroying container %q", c.ID) + log.Debugf("Destroying container, cid: %s", c.ID) if err := c.Sandbox.DestroyContainer(c.ID); err != nil { return fmt.Errorf("destroying container %q: %v", c.ID, err) } @@ -823,7 +835,7 @@ func (c *Container) stop() error { // Try killing gofer if it does not exit with container. if c.GoferPid != 0 { - log.Debugf("Killing gofer for container %q, PID: %d", c.ID, c.GoferPid) + log.Debugf("Killing gofer for container, cid: %s, PID: %d", c.ID, c.GoferPid) if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { // The gofer may already be stopped, log the error. log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err) @@ -1096,7 +1108,13 @@ func (c *Container) adjustGoferOOMScoreAdj() error { // TODO(gvisor.dev/issue/238): This call could race with other containers being // created at the same time and end up setting the wrong oom_score_adj to the // sandbox. Use rpc client to synchronize. -func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error { +func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, spec *specs.Spec, rootDir string, destroy bool) error { + // Adjustment can be skipped if the root container is exiting, because it + // brings down the entire sandbox. + if isRoot(spec) && destroy { + return nil + } + containers, err := loadSandbox(rootDir, s.ID) if err != nil { return fmt.Errorf("loading sandbox containers: %v", err) @@ -1110,53 +1128,34 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) // Get the lowest score for all containers. var lowScore int scoreFound := false - if len(containers) == 1 && specutils.SpecContainerType(containers[0].Spec) == specutils.ContainerTypeUnspecified { - // This is a single-container sandbox. Set the oom_score_adj to - // the value specified in the OCI bundle. - if containers[0].Spec.Process.OOMScoreAdj != nil { - scoreFound = true - lowScore = *containers[0].Spec.Process.OOMScoreAdj + for _, container := range containers { + // Special multi-container support for CRI. Ignore the root container when + // calculating oom_score_adj for the sandbox because it is the + // infrastructure (pause) container and always has a very low oom_score_adj. + // + // We will use OOMScoreAdj in the single-container case where the + // containerd container-type annotation is not present. + if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { + continue } - } else { - for _, container := range containers { - // Special multi-container support for CRI. Ignore the root - // container when calculating oom_score_adj for the sandbox because - // it is the infrastructure (pause) container and always has a very - // low oom_score_adj. - // - // We will use OOMScoreAdj in the single-container case where the - // containerd container-type annotation is not present. - if specutils.SpecContainerType(container.Spec) == specutils.ContainerTypeSandbox { - continue - } - if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { - scoreFound = true - lowScore = *container.Spec.Process.OOMScoreAdj - } + if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) { + scoreFound = true + lowScore = *container.Spec.Process.OOMScoreAdj } } // If the container is destroyed and remaining containers have no - // oomScoreAdj specified then we must revert to the oom_score_adj of the - // parent process. + // oomScoreAdj specified then we must revert to the original oom_score_adj + // saved with the root container. if !scoreFound && destroy { - ppid, err := specutils.GetParentPid(s.Pid) - if err != nil { - return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err) - } - pScore, err := specutils.GetOOMScoreAdj(ppid) - if err != nil { - return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err) - } - + lowScore = containers[0].Sandbox.OriginalOOMScoreAdj scoreFound = true - lowScore = pScore } - // Only set oom_score_adj if one of the containers has oom_score_adj set - // in the OCI bundle. If not, we need to inherit the parent process's - // oom_score_adj. + // Only set oom_score_adj if one of the containers has oom_score_adj set. If + // not, oom_score_adj is inherited from the parent process. + // // See: https://github.com/opencontainers/runtime-spec/blob/master/config.md#linux-process if !scoreFound { return nil diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index cc188f45b..fa99e403a 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -364,7 +364,7 @@ func TestLifecycle(t *testing.T) { defer c.Destroy() // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -387,7 +387,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -428,7 +428,7 @@ func TestLifecycle(t *testing.T) { } // Load the container from disk and check the status. - c, err = Load(rootDir, args.ID) + c, err = LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -451,7 +451,7 @@ func TestLifecycle(t *testing.T) { } // Loading the container by id should fail. - if _, err = Load(rootDir, args.ID); err == nil { + if _, err = LoadAndCheck(rootDir, args.ID); err == nil { t.Errorf("expected loading destroyed container to fail, but it did not") } }) @@ -1738,7 +1738,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { cids[2]: cids[2], } for shortid, longid := range unambiguous { - if _, err := Load(rootDir, shortid); err != nil { + if _, err := LoadAndCheck(rootDir, shortid); err != nil { t.Errorf("%q should resolve to %q: %v", shortid, longid, err) } } @@ -1749,7 +1749,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { "ba", } for _, shortid := range ambiguous { - if s, err := Load(rootDir, shortid); err == nil { + if s, err := LoadAndCheck(rootDir, shortid); err == nil { t.Errorf("%q should be ambiguous, but resolved to %q", shortid, s.ID) } } @@ -1976,11 +1976,11 @@ func doDestroyNotStartedTest(t *testing.T, vfs2 bool) { // TestDestroyStarting attempts to force a race between start and destroy. func TestDestroyStarting(t *testing.T) { - doDestroyNotStartedTest(t, false) + doDestroyStartingTest(t, false) } func TestDestroyStartedVFS2(t *testing.T) { - doDestroyNotStartedTest(t, true) + doDestroyStartingTest(t, true) } func doDestroyStartingTest(t *testing.T, vfs2 bool) { @@ -2007,7 +2007,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { // Container is not thread safe, so load another instance to run in // concurrently. - startCont, err := Load(rootDir, args.ID) + startCont, err := LoadAndCheck(rootDir, args.ID) if err != nil { t.Fatalf("error loading container: %v", err) } diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index 850e80290..cadc63bf3 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -15,6 +15,7 @@ package container import ( + "encoding/json" "fmt" "io/ioutil" "math" @@ -762,7 +763,7 @@ func TestMultiContainerKillAll(t *testing.T) { // processes still running inside. containers[1].SignalContainer(syscall.SIGKILL, false) op := func() error { - c, err := Load(conf.RootDir, ids[1]) + c, err := LoadAndCheck(conf.RootDir, ids[1]) if err != nil { return err } @@ -776,7 +777,7 @@ func TestMultiContainerKillAll(t *testing.T) { } } - c, err := Load(conf.RootDir, ids[1]) + c, err := LoadAndCheck(conf.RootDir, ids[1]) if err != nil { t.Fatalf("failed to load child container %q: %v", c.ID, err) } @@ -899,7 +900,7 @@ func TestMultiContainerDestroyStarting(t *testing.T) { // Container is not thread safe, so load another instance to run in // concurrently. - startCont, err := Load(rootDir, ids[i]) + startCont, err := LoadAndCheck(rootDir, ids[i]) if err != nil { t.Fatalf("error loading container: %v", err) } @@ -1766,3 +1767,72 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { }) } } + +func TestMultiContainerEvent(t *testing.T) { + conf := testutil.TestConfig(t) + rootDir, cleanup, err := testutil.SetupRootDir() + if err != nil { + t.Fatalf("error creating root dir: %v", err) + } + defer cleanup() + conf.RootDir = rootDir + + // Setup the containers. + sleep := []string{"/bin/sleep", "100"} + quick := []string{"/bin/true"} + podSpec, ids := createSpecs(sleep, sleep, quick) + containers, cleanup, err := startContainers(conf, podSpec, ids) + if err != nil { + t.Fatalf("error starting containers: %v", err) + } + defer cleanup() + + for _, cont := range containers { + t.Logf("Running containerd %s", cont.ID) + } + + // Wait for last container to stabilize the process count that is checked + // further below. + if ws, err := containers[2].Wait(); err != nil || ws != 0 { + t.Fatalf("Container.Wait, status: %v, err: %v", ws, err) + } + + // Check events for running containers. + for _, cont := range containers[:2] { + evt, err := cont.Event() + if err != nil { + t.Errorf("Container.Events(): %v", err) + } + if want := "stats"; evt.Type != want { + t.Errorf("Wrong event type, want: %s, got :%s", want, evt.Type) + } + if cont.ID != evt.ID { + t.Errorf("Wrong container ID, want: %s, got :%s", cont.ID, evt.ID) + } + // Event.Data is an interface, so it comes from the wire was + // map[string]string. Marshal and unmarshall again to the correc type. + data, err := json.Marshal(evt.Data) + if err != nil { + t.Fatalf("invalid event data: %v", err) + } + var stats boot.Stats + if err := json.Unmarshal(data, &stats); err != nil { + t.Fatalf("invalid event data: %v", err) + } + // One process per remaining container. + if want := uint64(2); stats.Pids.Current != want { + t.Errorf("Wrong number of PIDs, want: %d, got :%d", want, stats.Pids.Current) + } + } + + // Check that stop and destroyed containers return error. + if err := containers[1].Destroy(); err != nil { + t.Fatalf("container.Destroy: %v", err) + } + for _, cont := range containers[1:] { + _, err := cont.Event() + if err == nil { + t.Errorf("Container.Events() should have failed, cid:%s, state: %v", cont.ID, cont.Status) + } + } +} diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index c4309feb3..4a4110477 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -66,6 +66,10 @@ type Sandbox struct { // Cgroup has the cgroup configuration for the sandbox. Cgroup *cgroup.Cgroup `json:"cgroup"` + // OriginalOOMScoreAdj stores the value of oom_score_adj when the sandbox + // started, before it may be modified. + OriginalOOMScoreAdj int `json:"originalOomScoreAdj"` + // child is set if a sandbox process is a child of the current process. // // This field isn't saved to json, because only a creator of sandbox @@ -739,6 +743,11 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn } return err } + s.OriginalOOMScoreAdj, err = specutils.GetOOMScoreAdj(cmd.Process.Pid) + if err != nil { + return err + } + s.child = true s.Pid = cmd.Process.Pid log.Infof("Sandbox started, PID: %d", s.Pid) @@ -1133,11 +1142,11 @@ func (s *Sandbox) DestroyContainer(cid string) error { func (s *Sandbox) destroyContainer(cid string) error { if s.IsRootContainer(cid) { - log.Debugf("Destroying root container %q by destroying sandbox", cid) + log.Debugf("Destroying root container by destroying sandbox, cid: %s", cid) return s.destroy() } - log.Debugf("Destroying container %q in sandbox %q", cid, s.ID) + log.Debugf("Destroying container, cid: %s, sandbox: %s", cid, s.ID) conn, err := s.sandboxConnect() if err != nil { return err diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 45abc1425..fdbba1832 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -19,6 +19,7 @@ package specutils import ( "encoding/json" "fmt" + "io" "io/ioutil" "os" "path" @@ -169,7 +170,7 @@ func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) { // ReadSpecFromFile reads an OCI runtime spec from the given File, and // normalizes all relative paths into absolute by prepending the bundle dir. func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) { - if _, err := specFile.Seek(0, os.SEEK_SET); err != nil { + if _, err := specFile.Seek(0, io.SeekStart); err != nil { return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err) } specBytes, err := ioutil.ReadAll(specFile) @@ -480,35 +481,6 @@ func GetOOMScoreAdj(pid int) (int, error) { return strconv.Atoi(strings.TrimSpace(string(data))) } -// GetParentPid gets the parent process ID of the specified PID. -func GetParentPid(pid int) (int, error) { - data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) - if err != nil { - return 0, err - } - - var cpid string - var name string - var state string - var ppid int - // Parse after the binary name. - _, err = fmt.Sscanf(string(data), - "%v %v %v %d", - // cpid is ignored. - &cpid, - // name is ignored. - &name, - // state is ignored. - &state, - &ppid) - - if err != nil { - return 0, err - } - - return ppid, nil -} - // EnvVar looks for a varible value in the env slice assuming the following // format: "NAME=VALUE". func EnvVar(env []string, name string) (string, bool) { diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD index 32c139204..7dfd4b693 100644 --- a/test/benchmarks/base/BUILD +++ b/test/benchmarks/base/BUILD @@ -13,7 +13,7 @@ go_library( go_test( name = "base_test", - size = "large", + size = "enormous", srcs = [ "size_test.go", "startup_test.go", diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc index 6f032cac1..73321592b 100644 --- a/test/fuse/linux/stat_test.cc +++ b/test/fuse/linux/stat_test.cc @@ -76,8 +76,13 @@ TEST_F(StatTest, StatNormal) { // Check filesystem operation result. struct stat expected_stat = { .st_ino = attr.ino, +#ifdef __aarch64__ + .st_mode = expected_mode, + .st_nlink = attr.nlink, +#else .st_nlink = attr.nlink, .st_mode = expected_mode, +#endif .st_uid = attr.uid, .st_gid = attr.gid, .st_rdev = attr.rdev, @@ -152,8 +157,13 @@ TEST_F(StatTest, FstatNormal) { // Check filesystem operation result. struct stat expected_stat = { .st_ino = attr.ino, +#ifdef __aarch64__ + .st_mode = expected_mode, + .st_nlink = attr.nlink, +#else .st_nlink = attr.nlink, .st_mode = expected_mode, +#endif .st_uid = attr.uid, .st_gid = attr.gid, .st_rdev = attr.rdev, diff --git a/test/iptables/README.md b/test/iptables/README.md index 28ab195ca..1196f8eb5 100644 --- a/test/iptables/README.md +++ b/test/iptables/README.md @@ -2,8 +2,38 @@ iptables tests are run via `make iptables-tests`. -iptables requires raw socket support, so you must add the `--net-raw=true` flag -to `/etc/docker/daemon.json` in order to use it. +iptables require some extra Docker configuration to work. Enable IPv6 in +`/etc/docker/daemon.json` (make sure to restart Docker if you change this file): + +```json +{ + "experimental": true, + "fixed-cidr-v6": "2001:db8:1::/64", + "ipv6": true, + // Runtimes and other Docker config... +} +``` + +And if you're running manually (i.e. not using the `make` target), you'll need +to: + +* Enable iptables via `modprobe iptables_filter && modprobe ip6table_filter`. +* Enable `--net-raw` in your chosen runtime in `/etc/docker/daemon.json` (make + sure to restart Docker if you change this file). + +The resulting runtime should look something like this: + +```json +"runsc": { + "path": "/tmp/iptables/runsc", + "runtimeArgs": [ + "--debug-log", + "/tmp/iptables/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%", + "--net-raw" + ] +}, +// ... +``` ## Test Structure diff --git a/test/packetimpact/netdevs/netdevs.go b/test/packetimpact/netdevs/netdevs.go index eecfe0730..006988896 100644 --- a/test/packetimpact/netdevs/netdevs.go +++ b/test/packetimpact/netdevs/netdevs.go @@ -40,7 +40,7 @@ var ( deviceLine = regexp.MustCompile(`^\s*(\d+): (\w+)`) linkLine = regexp.MustCompile(`^\s*link/\w+ ([0-9a-fA-F:]+)`) inetLine = regexp.MustCompile(`^\s*inet ([0-9./]+)`) - inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-Z:/]+)`) + inet6Line = regexp.MustCompile(`^\s*inet6 ([0-9a-fA-F:/]+)`) ) // ParseDevices parses the output from `ip addr show` into a map from device diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 1546d0d51..c03c2c62c 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -252,6 +252,9 @@ ALL_TESTS = [ expect_netstack_failure = True, ), PacketimpactTestInfo( + name = "ipv4_fragment_reassembly", + ), + PacketimpactTestInfo( name = "ipv6_fragment_reassembly", ), PacketimpactTestInfo( diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index a90046f69..8fa585804 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -839,6 +839,61 @@ func (conn *TCPIPv4) Drain(t *testing.T) { conn.sniffer.Drain(t) } +// IPv4Conn maintains the state for all the layers in a IPv4 connection. +type IPv4Conn Connection + +// NewIPv4Conn creates a new IPv4Conn connection with reasonable defaults. +func NewIPv4Conn(t *testing.T, outgoingIPv4, incomingIPv4 IPv4) IPv4Conn { + t.Helper() + + etherState, err := newEtherState(Ether{}, Ether{}) + if err != nil { + t.Fatalf("can't make EtherState: %s", err) + } + ipv4State, err := newIPv4State(outgoingIPv4, incomingIPv4) + if err != nil { + t.Fatalf("can't make IPv4State: %s", err) + } + + injector, err := NewInjector(t) + if err != nil { + t.Fatalf("can't make injector: %s", err) + } + sniffer, err := NewSniffer(t) + if err != nil { + t.Fatalf("can't make sniffer: %s", err) + } + + return IPv4Conn{ + layerStates: []layerState{etherState, ipv4State}, + injector: injector, + sniffer: sniffer, + } +} + +// Send sends a frame with ipv4 overriding the IPv4 layer defaults and +// additionalLayers added after it. +func (c *IPv4Conn) Send(t *testing.T, ipv4 IPv4, additionalLayers ...Layer) { + t.Helper() + + (*Connection)(c).send(t, Layers{&ipv4}, additionalLayers...) +} + +// Close cleans up any resources held. +func (c *IPv4Conn) Close(t *testing.T) { + t.Helper() + + (*Connection)(c).Close(t) +} + +// ExpectFrame expects a frame that matches the provided Layers within the +// timeout specified. If it doesn't arrive in time, an error is returned. +func (c *IPv4Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) { + t.Helper() + + return (*Connection)(c).ExpectFrame(t, frame, timeout) +} + // IPv6Conn maintains the state for all the layers in a IPv6 connection. type IPv6Conn Connection diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 8c2de5a9f..c30c77a17 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -298,6 +298,18 @@ packetimpact_testbench( ) packetimpact_testbench( + name = "ipv4_fragment_reassembly", + srcs = ["ipv4_fragment_reassembly_test.go"], + deps = [ + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( name = "ipv6_fragment_reassembly", srcs = ["ipv6_fragment_reassembly_test.go"], deps = [ @@ -305,6 +317,7 @@ packetimpact_testbench( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//test/packetimpact/testbench", + "@com_github_google_go_cmp//cmp:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/packetimpact/tests/ipv4_fragment_reassembly_test.go b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go new file mode 100644 index 000000000..65c0df140 --- /dev/null +++ b/test/packetimpact/tests/ipv4_fragment_reassembly_test.go @@ -0,0 +1,142 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipv4_fragment_reassembly_test + +import ( + "flag" + "math/rand" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +type fragmentInfo struct { + offset uint16 + size uint16 + more uint8 +} + +func TestIPv4FragmentReassembly(t *testing.T) { + const fragmentID = 42 + icmpv4ProtoNum := uint8(header.ICMPv4ProtocolNumber) + + tests := []struct { + description string + ipPayloadLen int + fragments []fragmentInfo + expectReply bool + }{ + { + description: "basic reassembly", + ipPayloadLen: 2000, + fragments: []fragmentInfo{ + {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments}, + {offset: 1000, size: 1000, more: 0}, + }, + expectReply: true, + }, + { + description: "out of order fragments", + ipPayloadLen: 2000, + fragments: []fragmentInfo{ + {offset: 1000, size: 1000, more: 0}, + {offset: 0, size: 1000, more: header.IPv4FlagMoreFragments}, + }, + expectReply: true, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv4Conn(t, testbench.IPv4{}, testbench.IPv4{}) + defer conn.Close(t) + + data := make([]byte, test.ipPayloadLen) + icmp := header.ICMPv4(data[:header.ICMPv4MinimumSize]) + icmp.SetType(header.ICMPv4Echo) + icmp.SetCode(header.ICMPv4UnusedCode) + icmp.SetChecksum(0) + icmp.SetSequence(0) + icmp.SetIdent(0) + originalPayload := data[header.ICMPv4MinimumSize:] + if _, err := rand.Read(originalPayload); err != nil { + t.Fatalf("rand.Read: %s", err) + } + cksum := header.ICMPv4Checksum( + icmp, + buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), + ) + icmp.SetChecksum(cksum) + + for _, fragment := range test.fragments { + conn.Send(t, + testbench.IPv4{ + Protocol: &icmpv4ProtoNum, + FragmentOffset: testbench.Uint16(fragment.offset), + Flags: testbench.Uint8(fragment.more), + ID: testbench.Uint16(fragmentID), + }, + &testbench.Payload{ + Bytes: data[fragment.offset:][:fragment.size], + }) + } + + var bytesReceived int + reassembledPayload := make([]byte, test.ipPayloadLen) + for { + incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv4{}, + &testbench.ICMPv4{}, + }, time.Second) + if err != nil { + // Either an unexpected frame was received, or none at all. + if bytesReceived < test.ipPayloadLen { + t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err) + } + break + } + if !test.expectReply { + t.Fatalf("unexpected reply received:\n%s", incomingFrame) + } + ipPayload, err := incomingFrame[2 /* ICMPv4 */].ToBytes() + if err != nil { + t.Fatalf("failed to parse ICMPv4 header: incomingPacket[2].ToBytes() = (_, %s)", err) + } + offset := *incomingFrame[1 /* IPv4 */].(*testbench.IPv4).FragmentOffset + if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) { + t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload)) + } + bytesReceived += len(ipPayload) + } + + if test.expectReply { + if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv4MinimumSize:]); diff != "" { + t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go index a24c85566..4a29de688 100644 --- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -15,154 +15,137 @@ package ipv6_fragment_reassembly_test import ( - "bytes" - "encoding/binary" - "encoding/hex" "flag" + "math/rand" "net" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/test/packetimpact/testbench" ) -const ( - // The payload length for the first fragment we send. This number - // is a multiple of 8 near 750 (half of 1500). - firstPayloadLength = 752 - // The ID field for our outgoing fragments. - fragmentID = 1 - // A node must be able to accept a fragmented packet that, - // after reassembly, is as large as 1500 octets. - reassemblyCap = 1500 -) - func init() { testbench.RegisterFlags(flag.CommandLine) } -func TestIPv6FragmentReassembly(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) - defer conn.Close(t) - - firstPayloadToSend := make([]byte, firstPayloadLength) - for i := range firstPayloadToSend { - firstPayloadToSend[i] = 'A' - } - - secondPayloadLength := reassemblyCap - firstPayloadLength - header.ICMPv6EchoMinimumSize - secondPayloadToSend := firstPayloadToSend[:secondPayloadLength] - - icmpv6EchoPayload := make([]byte, 4) - binary.BigEndian.PutUint16(icmpv6EchoPayload[0:], 0) - binary.BigEndian.PutUint16(icmpv6EchoPayload[2:], 0) - icmpv6EchoPayload = append(icmpv6EchoPayload, firstPayloadToSend...) - - lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) - rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) - icmpv6 := testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - Payload: icmpv6EchoPayload, - } - icmpv6Bytes, err := icmpv6.ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6: %s", err) - } - cksum := header.ICMPv6Checksum( - header.ICMPv6(icmpv6Bytes), - lIP, - rIP, - buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}), - ) - - conn.Send(t, testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - FragmentOffset: testbench.Uint16(0), - MoreFragments: testbench.Bool(true), - Identification: testbench.Uint32(fragmentID), - }, - &testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - Payload: icmpv6EchoPayload, - Checksum: &cksum, - }) +type fragmentInfo struct { + offset uint16 + size uint16 + more bool +} +func TestIPv6FragmentReassembly(t *testing.T) { + const fragmentID = 42 icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber) - conn.Send(t, testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - NextHeader: &icmpv6ProtoNum, - FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8), - MoreFragments: testbench.Bool(false), - Identification: testbench.Uint32(fragmentID), + tests := []struct { + description string + ipPayloadLen int + fragments []fragmentInfo + expectReply bool + }{ + { + description: "basic reassembly", + ipPayloadLen: 1500, + fragments: []fragmentInfo{ + {offset: 0, size: 760, more: true}, + {offset: 760, size: 740, more: false}, + }, + expectReply: true, }, - &testbench.Payload{ - Bytes: secondPayloadToSend, - }) - - gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{ - &testbench.Ether{}, - &testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - FragmentOffset: testbench.Uint16(0), - MoreFragments: testbench.Bool(true), + { + description: "out of order fragments", + ipPayloadLen: 3000, + fragments: []fragmentInfo{ + {offset: 0, size: 1024, more: true}, + {offset: 2048, size: 952, more: false}, + {offset: 1024, size: 1024, more: true}, + }, + expectReply: true, }, - &testbench.ICMPv6{ - Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), - Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), - }, - }, time.Second) - if err != nil { - t.Fatalf("expected a fragmented ICMPv6 Echo Reply, but got none: %s", err) } - id := *gotEchoReplyFirstPart[2].(*testbench.IPv6FragmentExtHdr).Identification - gotFirstPayload, err := gotEchoReplyFirstPart[len(gotEchoReplyFirstPart)-1].ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6: %s", err) - } - icmpPayload := gotFirstPayload[header.ICMPv6EchoMinimumSize:] - receivedLen := len(icmpPayload) - wantSecondPayloadLen := reassemblyCap - header.ICMPv6EchoMinimumSize - receivedLen - wantFirstPayload := make([]byte, receivedLen) - for i := range wantFirstPayload { - wantFirstPayload[i] = 'A' - } - wantSecondPayload := wantFirstPayload[:wantSecondPayloadLen] - if !bytes.Equal(icmpPayload, wantFirstPayload) { - t.Fatalf("received unexpected payload, got: %s, want: %s", - hex.Dump(icmpPayload), - hex.Dump(wantFirstPayload)) - } - - gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{ - &testbench.Ether{}, - &testbench.IPv6{}, - &testbench.IPv6FragmentExtHdr{ - NextHeader: &icmpv6ProtoNum, - FragmentOffset: testbench.Uint16(uint16((receivedLen + header.ICMPv6EchoMinimumSize) / 8)), - MoreFragments: testbench.Bool(false), - Identification: &id, - }, - &testbench.ICMPv6{}, - }, time.Second) - if err != nil { - t.Fatalf("expected the rest of ICMPv6 Echo Reply, but got none: %s", err) - } - secondPayload, err := gotEchoReplySecondPart[len(gotEchoReplySecondPart)-1].ToBytes() - if err != nil { - t.Fatalf("failed to serialize ICMPv6 Echo Reply: %s", err) - } - if !bytes.Equal(secondPayload, wantSecondPayload) { - t.Fatalf("received unexpected payload, got: %s, want: %s", - hex.Dump(secondPayload), - hex.Dump(wantSecondPayload)) + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{}) + defer conn.Close(t) + + lIP := tcpip.Address(net.ParseIP(testbench.LocalIPv6).To16()) + rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) + + data := make([]byte, test.ipPayloadLen) + icmp := header.ICMPv6(data[:header.ICMPv6HeaderSize]) + icmp.SetType(header.ICMPv6EchoRequest) + icmp.SetCode(header.ICMPv6UnusedCode) + icmp.SetChecksum(0) + originalPayload := data[header.ICMPv6HeaderSize:] + if _, err := rand.Read(originalPayload); err != nil { + t.Fatalf("rand.Read: %s", err) + } + + cksum := header.ICMPv6Checksum( + icmp, + lIP, + rIP, + buffer.NewVectorisedView(len(originalPayload), []buffer.View{originalPayload}), + ) + icmp.SetChecksum(cksum) + + for _, fragment := range test.fragments { + conn.Send(t, testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{ + NextHeader: &icmpv6ProtoNum, + FragmentOffset: testbench.Uint16(fragment.offset / header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit), + MoreFragments: testbench.Bool(fragment.more), + Identification: testbench.Uint32(fragmentID), + }, + &testbench.Payload{ + Bytes: data[fragment.offset:][:fragment.size], + }) + } + + var bytesReceived int + reassembledPayload := make([]byte, test.ipPayloadLen) + for { + incomingFrame, err := conn.ExpectFrame(t, testbench.Layers{ + &testbench.Ether{}, + &testbench.IPv6{}, + &testbench.IPv6FragmentExtHdr{}, + &testbench.ICMPv6{}, + }, time.Second) + if err != nil { + // Either an unexpected frame was received, or none at all. + if bytesReceived < test.ipPayloadLen { + t.Fatalf("received %d bytes out of %d, then conn.ExpectFrame(_, _, time.Second) failed with %s", bytesReceived, test.ipPayloadLen, err) + } + break + } + if !test.expectReply { + t.Fatalf("unexpected reply received:\n%s", incomingFrame) + } + ipPayload, err := incomingFrame[3 /* ICMPv6 */].ToBytes() + if err != nil { + t.Fatalf("failed to parse ICMPv6 header: incomingPacket[3].ToBytes() = (_, %s)", err) + } + offset := *incomingFrame[2 /* IPv6FragmentExtHdr */].(*testbench.IPv6FragmentExtHdr).FragmentOffset + offset *= header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit + if copied := copy(reassembledPayload[offset:], ipPayload); copied != len(ipPayload) { + t.Fatalf("wrong number of bytes copied into reassembledPayload: got = %d, want = %d", copied, len(ipPayload)) + } + bytesReceived += len(ipPayload) + } + + if test.expectReply { + if diff := cmp.Diff(originalPayload, reassembledPayload[header.ICMPv6HeaderSize:]); diff != "" { + t.Fatalf("reassembledPayload mismatch (-want +got):\n%s", diff) + } + } + }) } } diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go index 4243eb59e..0dcc0fdea 100644 --- a/test/root/oom_score_adj_test.go +++ b/test/root/oom_score_adj_test.go @@ -40,11 +40,7 @@ var ( // TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a // single container sandbox. func TestOOMScoreAdjSingle(t *testing.T) { - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid()) if err != nil { t.Fatalf("getting parent oom_score_adj: %v", err) } @@ -122,11 +118,7 @@ func TestOOMScoreAdjSingle(t *testing.T) { // TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a // multi-container sandbox. func TestOOMScoreAdjMulti(t *testing.T) { - ppid, err := specutils.GetParentPid(os.Getpid()) - if err != nil { - t.Fatalf("getting parent pid: %v", err) - } - parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid) + parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(os.Getppid()) if err != nil { t.Fatalf("getting parent oom_score_adj: %v", err) } diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 9b5994d59..4992147d4 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -97,10 +97,10 @@ def _syscall_test( # we figure out how to request ipv4 sockets on Guitar machines. if network == "host": tags.append("noguitar") - tags.append("block-network") # Disable off-host networking. tags.append("requires-net:loopback") + tags.append("block-network") # gotsan makes sense only if tests are running in gVisor. if platform == "native": diff --git a/test/runtimes/exclude/php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index f984a579a..c051fe571 100644 --- a/test/runtimes/exclude/php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv @@ -8,6 +8,7 @@ ext/mbstring/tests/bug77165.phpt,, ext/mbstring/tests/bug77454.phpt,, ext/mbstring/tests/mb_convert_encoding_leak.phpt,, ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/pcre/tests/cache_limit.phpt,,Broken test - Flaky ext/session/tests/session_module_name_variation4.phpt,,Flaky ext/session/tests/session_set_save_handler_class_018.phpt,, ext/session/tests/session_set_save_handler_iface_003.phpt,, @@ -29,7 +30,7 @@ ext/standard/tests/file/php_fd_wrapper_04.phpt,, ext/standard/tests/file/realpath_bug77484.phpt,b/162894969,VFS1 only failure ext/standard/tests/file/rename_variation.phpt,b/68717309, ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223,VFS1 only failure ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv index 911f22855..e9fef03b7 100644 --- a/test/runtimes/exclude/python3.7.3.csv +++ b/test/runtimes/exclude/python3.7.3.csv @@ -4,7 +4,6 @@ test_asyncore,b/162973328, test_epoll,b/162983393, test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. test_httplib,b/163000009,OSError: [Errno 98] Address already in use -test_imaplib,b/162979661, test_logging,b/162980079, test_multiprocessing_fork,,Flaky. Sometimes times out. test_multiprocessing_forkserver,,Flaky. Sometimes times out. diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go index 78285cb0e..64e6e14db 100644 --- a/test/runtimes/runner/lib/lib.go +++ b/test/runtimes/runner/lib/lib.go @@ -34,8 +34,16 @@ import ( // RunTests is a helper that is called by main. It exists so that we can run // defered functions before exiting. It returns an exit code that should be // passed to os.Exit. -func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int { - // Get tests to exclude.. +func RunTests(lang, image, excludeFile string, partitionNum, totalPartitions, batchSize int, timeout time.Duration) int { + if partitionNum <= 0 || totalPartitions <= 0 || partitionNum > totalPartitions { + fmt.Fprintf(os.Stderr, "invalid partition %d of %d", partitionNum, totalPartitions) + return 1 + } + + // TODO(gvisor.dev/issue/1624): Remove those tests from all exclude lists + // that only fail with VFS1. + + // Get tests to exclude. excludes, err := getExcludes(excludeFile) if err != nil { fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) @@ -55,7 +63,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat // Get a slice of tests to run. This will also start a single Docker // container that will be used to run each test. The final test will // stop the Docker container. - tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes) + tests, err := getTests(ctx, d, lang, image, partitionNum, totalPartitions, batchSize, timeout, excludes) if err != nil { fmt.Fprintf(os.Stderr, "%s\n", err.Error()) return 1 @@ -66,7 +74,7 @@ func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Durat } // getTests executes all tests as table tests. -func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { +func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, partitionNum, totalPartitions, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { // Start the container. opts := dockerutil.RunOpts{ Image: fmt.Sprintf("runtimes/%s", image), @@ -86,6 +94,14 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, // shard. tests := strings.Fields(list) sort.Strings(tests) + + partitionSize := len(tests) / totalPartitions + if partitionNum == totalPartitions { + tests = tests[(partitionNum-1)*partitionSize:] + } else { + tests = tests[(partitionNum-1)*partitionSize : partitionNum*partitionSize] + } + indices, err := testutil.TestIndicesForShard(len(tests)) if err != nil { return nil, fmt.Errorf("TestsForShard() failed: %v", err) @@ -116,8 +132,15 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, err error ) + state, err := d.Status(ctx) + if err != nil { + t.Fatalf("Could not find container status: %v", err) + } + if !state.Running { + t.Fatalf("container is not running: state = %s", state.Status) + } + go func() { - fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ",")) close(done) }() @@ -125,12 +148,12 @@ func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, select { case <-done: if err == nil { - fmt.Printf("PASS: (%v)\n\n", time.Since(now)) + fmt.Printf("PASS: (%v) %d tests passed\n", time.Since(now), len(tcs)) return } - t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) + t.Errorf("FAIL: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output) case <-time.After(timeout): - t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) + t.Errorf("TIMEOUT: (%v):\nBatch:\n%s\nOutput:\n%s\n", time.Since(now), strings.Join(tcs, "\n"), output) } }, }) diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index ec79a22c2..5b3443e36 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -25,11 +25,13 @@ import ( ) var ( - lang = flag.String("lang", "", "language runtime to test") - image = flag.String("image", "", "docker image with runtime tests") - excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") - batchSize = flag.Int("batch", 50, "number of test cases run in one command") - timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") + lang = flag.String("lang", "", "language runtime to test") + image = flag.String("image", "", "docker image with runtime tests") + excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") + partition = flag.Int("partition", 1, "partition number, this is 1-indexed") + totalPartitions = flag.Int("total_partitions", 1, "total number of partitions") + batchSize = flag.Int("batch", 50, "number of test cases run in one command") + timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") ) func main() { @@ -38,5 +40,5 @@ func main() { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout)) + os.Exit(lib.RunTests(*lang, *image, *excludeFile, *partition, *totalPartitions, *batchSize, *timeout)) } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index f66a9ceb4..b5a4ef4df 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -695,6 +695,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:socket_ip_unbound_netlink_test", +) + +syscall_test( test = "//test/syscalls/linux:socket_netdevice_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index c94c1d5bd..2350f7e69 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1802,10 +1802,14 @@ cc_binary( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/time", gtest, + "//test/util:file_descriptor", + "//test/util:fs_util", "//test/util:logging", + "//test/util:memory_util", "//test/util:multiprocess_util", "//test/util:platform_util", "//test/util:signal_util", + "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", "//test/util:time_util", @@ -2102,10 +2106,12 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2125,6 +2131,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2138,10 +2145,12 @@ cc_binary( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", gtest, + "//test/util:signal_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) @@ -2880,6 +2889,24 @@ cc_binary( ) cc_binary( + name = "socket_ip_unbound_netlink_test", + testonly = 1, + srcs = [ + "socket_ip_unbound_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "socket_domain_test", testonly = 1, srcs = [ diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc index 926690eb8..13c19d4a8 100644 --- a/test/syscalls/linux/ptrace.cc +++ b/test/syscalls/linux/ptrace.cc @@ -30,10 +30,13 @@ #include "absl/flags/flag.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "test/util/fs_util.h" #include "test/util/logging.h" +#include "test/util/memory_util.h" #include "test/util/multiprocess_util.h" #include "test/util/platform_util.h" #include "test/util/signal_util.h" +#include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" #include "test/util/time_util.h" @@ -113,10 +116,21 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { // except disabled. SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(YamaPtraceScope()) > 0); - constexpr long kBeforePokeDataValue = 10; - constexpr long kAfterPokeDataValue = 20; + // Test PTRACE_POKE/PEEKDATA on both anonymous and file mappings. + const auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + ASSERT_NO_ERRNO(Truncate(file.path(), kPageSize)); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + const auto file_mapping = ASSERT_NO_ERRNO_AND_VALUE(Mmap( + nullptr, kPageSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd.get(), 0)); - volatile long word = kBeforePokeDataValue; + constexpr long kBeforePokeDataAnonValue = 10; + constexpr long kAfterPokeDataAnonValue = 20; + constexpr long kBeforePokeDataFileValue = 0; // implicit, due to truncate() + constexpr long kAfterPokeDataFileValue = 30; + + volatile long anon_word = kBeforePokeDataAnonValue; + auto* file_word_ptr = static_cast<volatile long*>(file_mapping.ptr()); pid_t const child_pid = fork(); if (child_pid == 0) { @@ -134,12 +148,22 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { MaybeSave(); TEST_CHECK(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP); - // Replace the value of word in the parent process with kAfterPokeDataValue. - long const parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &word, 0); + // Replace the value of anon_word in the parent process with + // kAfterPokeDataAnonValue. + long parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, &anon_word, 0); + MaybeSave(); + TEST_CHECK(parent_word == kBeforePokeDataAnonValue); + TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, &anon_word, + kAfterPokeDataAnonValue) == 0); + MaybeSave(); + + // Replace the value pointed to by file_word_ptr in the mapped file with + // kAfterPokeDataFileValue, via the parent process' mapping. + parent_word = ptrace(PTRACE_PEEKDATA, parent_pid, file_word_ptr, 0); MaybeSave(); - TEST_CHECK(parent_word == kBeforePokeDataValue); - TEST_PCHECK( - ptrace(PTRACE_POKEDATA, parent_pid, &word, kAfterPokeDataValue) == 0); + TEST_CHECK(parent_word == kBeforePokeDataFileValue); + TEST_PCHECK(ptrace(PTRACE_POKEDATA, parent_pid, file_word_ptr, + kAfterPokeDataFileValue) == 0); MaybeSave(); // Detach from the parent and suppress the SIGSTOP. If the SIGSTOP is not @@ -160,7 +184,8 @@ TEST(PtraceTest, AttachParent_PeekData_PokeData_SignalSuppression) { << " status " << status; // Check that the child's PTRACE_POKEDATA was effective. - EXPECT_EQ(kAfterPokeDataValue, word); + EXPECT_EQ(kAfterPokeDataAnonValue, anon_word); + EXPECT_EQ(kAfterPokeDataFileValue, *file_word_ptr); } TEST(PtraceTest, GetSigMask) { diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index ed6a1c2aa..890f4a246 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <signal.h> #include <sys/ipc.h> #include <sys/sem.h> #include <sys/types.h> @@ -542,6 +543,236 @@ TEST(SemaphoreTest, SemCtlIpcStat) { SyscallFailsWithErrno(EACCES)); } +// Calls semctl(semid, 0, cmd) until the returned value is >= target, an +// internal timeout expires, or semctl returns an error. +PosixErrorOr<int> WaitSemctl(int semid, int target, int cmd) { + constexpr absl::Duration timeout = absl::Seconds(10); + const auto deadline = absl::Now() + timeout; + int semcnt = 0; + while (absl::Now() < deadline) { + semcnt = semctl(semid, 0, cmd); + if (semcnt < 0) { + return PosixError(errno, "semctl(GETZCNT) failed"); + } + if (semcnt >= target) { + break; + } + absl::SleepFor(absl::Milliseconds(10)); + } + return semcnt; +} + +TEST(SemaphoreTest, SemopGetzcnt) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + // Create a write only semaphore set. + AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + // No read permission to retrieve semzcnt. + EXPECT_THAT(semctl(sem.get(), 0, GETZCNT), SyscallFailsWithErrno(EACCES)); + + // Remove the calling thread's read permission. + struct semid_ds ds = {}; + ds.sem_perm.uid = getuid(); + ds.sem_perm.gid = getgid(); + ds.sem_perm.mode = 0600; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds()); + + std::vector<pid_t> children; + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); + + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + constexpr size_t kLoops = 10; + for (auto i = 0; i < kLoops; i++) { + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); + _exit(0); + } + children.push_back(child_pid); + } + + EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETZCNT), + IsPosixErrorOkAndHolds(kLoops)); + // Set semval to 0, which wakes up children that sleep on the semop. + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 0), SyscallSucceeds()); + for (const auto& child_pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0); +} + +TEST(SemaphoreTest, SemopGetzcntOnSetRemoval) { + auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT); + ASSERT_THAT(semid, SyscallSucceeds()); + ASSERT_THAT(semctl(semid, 0, SETVAL, 1), SyscallSucceeds()); + ASSERT_EQ(semctl(semid, 0, GETZCNT), 0); + + auto child_pid = fork(); + if (child_pid == 0) { + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + + // Ensure that wait will only unblock when the semaphore is removed. On + // EINTR retry it may race with deletion and return EINVAL. + TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(semid, 1, GETZCNT), IsPosixErrorOkAndHolds(1)); + // Remove the semaphore set, which fails the sleep semop. + ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds()); + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_THAT(semctl(semid, 0, GETZCNT), SyscallFailsWithErrno(EINVAL)); +} + +TEST(SemaphoreTest, SemopGetzcntOnSignal_NoRandomSave) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds()); + ASSERT_EQ(semctl(sem.get(), 0, GETZCNT), 0); + + // Saving will cause semop() to be spuriously interrupted. + DisableSave ds; + + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR); + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = 0; + + TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(sem.get(), 1, GETZCNT), IsPosixErrorOkAndHolds(1)); + // Send a signal to the child, which fails the sleep semop. + ASSERT_EQ(kill(child_pid, SIGHUP), 0); + + ds.reset(); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_EQ(semctl(sem.get(), 0, GETZCNT), 0); +} + +TEST(SemaphoreTest, SemopGetncnt) { + // Drop CAP_IPC_OWNER which allows us to bypass semaphore permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_IPC_OWNER, false)); + // Create a write only semaphore set. + AutoSem sem(semget(IPC_PRIVATE, 1, 0200 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + + // No read permission to retrieve semzcnt. + EXPECT_THAT(semctl(sem.get(), 0, GETNCNT), SyscallFailsWithErrno(EACCES)); + + // Remove the calling thread's read permission. + struct semid_ds ds = {}; + ds.sem_perm.uid = getuid(); + ds.sem_perm.gid = getgid(); + ds.sem_perm.mode = 0600; + ASSERT_THAT(semctl(sem.get(), 0, IPC_SET, &ds), SyscallSucceeds()); + + std::vector<pid_t> children; + + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + constexpr size_t kLoops = 10; + for (auto i = 0; i < kLoops; i++) { + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(RetryEINTR(semop)(sem.get(), &buf, 1) == 0); + _exit(0); + } + children.push_back(child_pid); + } + EXPECT_THAT(WaitSemctl(sem.get(), kLoops, GETNCNT), + IsPosixErrorOkAndHolds(kLoops)); + // Set semval to 1, which wakes up children that sleep on the semop. + ASSERT_THAT(semctl(sem.get(), 0, SETVAL, kLoops), SyscallSucceeds()); + for (const auto& child_pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0); +} + +TEST(SemaphoreTest, SemopGetncntOnSetRemoval) { + auto semid = semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT); + ASSERT_THAT(semid, SyscallSucceeds()); + ASSERT_EQ(semctl(semid, 0, GETNCNT), 0); + + auto child_pid = fork(); + if (child_pid == 0) { + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + + // Ensure that wait will only unblock when the semaphore is removed. On + // EINTR retry it may race with deletion and return EINVAL + TEST_PCHECK(RetryEINTR(semop)(semid, &buf, 1) < 0 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + EXPECT_THAT(WaitSemctl(semid, 1, GETNCNT), IsPosixErrorOkAndHolds(1)); + // Remove the semaphore set, which fails the sleep semop. + ASSERT_THAT(semctl(semid, 0, IPC_RMID), SyscallSucceeds()); + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_THAT(semctl(semid, 0, GETNCNT), SyscallFailsWithErrno(EINVAL)); +} + +TEST(SemaphoreTest, SemopGetncntOnSignal_NoRandomSave) { + AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); + ASSERT_THAT(sem.get(), SyscallSucceeds()); + ASSERT_EQ(semctl(sem.get(), 0, GETNCNT), 0); + + // Saving will cause semop() to be spuriously interrupted. + DisableSave ds; + + auto child_pid = fork(); + if (child_pid == 0) { + TEST_PCHECK(signal(SIGHUP, [](int sig) -> void {}) != SIG_ERR); + struct sembuf buf = {}; + buf.sem_num = 0; + buf.sem_op = -1; + + TEST_PCHECK(semop(sem.get(), &buf, 1) < 0 && errno == EINTR); + _exit(0); + } + EXPECT_THAT(WaitSemctl(sem.get(), 1, GETNCNT), IsPosixErrorOkAndHolds(1)); + // Send a signal to the child, which fails the sleep semop. + ASSERT_EQ(kill(child_pid, SIGHUP), 0); + + ds.reset(); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + EXPECT_EQ(semctl(sem.get(), 0, GETNCNT), 0); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index a8bfb01f1..cf0977118 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -25,9 +25,11 @@ #include "absl/time/time.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/timer_util.h" namespace gvisor { namespace testing { @@ -629,6 +631,57 @@ TEST(SendFileTest, SendFileToPipe) { SyscallSucceedsWithValue(kDataSize)); } +static volatile int signaled = 0; +void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } + +TEST(SendFileTest, ToEventFDDoesNotSpin_NoRandomSave) { + FileDescriptor efd = ASSERT_NO_ERRNO_AND_VALUE(NewEventFD(0, 0)); + + // Write the maximum value of an eventfd to a file. + const uint64_t kMaxEventfdValue = 0xfffffffffffffffe; + const auto tempfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const auto tempfd = ASSERT_NO_ERRNO_AND_VALUE(Open(tempfile.path(), O_RDWR)); + ASSERT_THAT( + pwrite(tempfd.get(), &kMaxEventfdValue, sizeof(kMaxEventfdValue), 0), + SyscallSucceedsWithValue(sizeof(kMaxEventfdValue))); + + // Set the eventfd's value to 1. + const uint64_t kOne = 1; + ASSERT_THAT(write(efd.get(), &kOne, sizeof(kOne)), + SyscallSucceedsWithValue(sizeof(kOne))); + + // Set up signal handler. + struct sigaction sa = {}; + sa.sa_sigaction = SigUsr1Handler; + sa.sa_flags = SA_SIGINFO; + const auto cleanup_sigact = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); + + // Send SIGUSR1 to this thread in 1 second. + struct sigevent sev = {}; + sev.sigev_notify = SIGEV_THREAD_ID; + sev.sigev_signo = SIGUSR1; + sev.sigev_notify_thread_id = gettid(); + auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); + struct itimerspec its = {}; + its.it_value = absl::ToTimespec(absl::Seconds(1)); + DisableSave ds; // Asserting an EINTR. + ASSERT_NO_ERRNO(timer.Set(0, its)); + + // Sendfile from tempfd to the eventfd. Since the eventfd is not already at + // its maximum value, the eventfd is "ready for writing"; however, since the + // eventfd's existing value plus the new value would exceed the maximum, the + // write should internally fail with EWOULDBLOCK. In this case, sendfile() + // should block instead of spinning, and eventually be interrupted by our + // timer. See b/172075629. + EXPECT_THAT( + sendfile(efd.get(), tempfd.get(), nullptr, sizeof(kMaxEventfdValue)), + SyscallFailsWithErrno(EINTR)); + + // Signal should have been handled. + EXPECT_EQ(signaled, 1); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 39a68c5a5..e19a83413 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -940,7 +940,7 @@ void setupTimeWaitClose(const TestAddress* listener, } // shutdown to trigger TIME_WAIT. - ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds()); + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_WR), SyscallSucceeds()); { const int kTimeout = 10000; struct pollfd pfd = { @@ -950,7 +950,8 @@ void setupTimeWaitClose(const TestAddress* listener, ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); ASSERT_EQ(pfd.revents, POLLIN); } - ScopedThread t([&]() { + ASSERT_THAT(shutdown(passive_closefd.get(), SHUT_WR), SyscallSucceeds()); + { constexpr int kTimeout = 10000; constexpr int16_t want_events = POLLHUP; struct pollfd pfd = { @@ -958,11 +959,8 @@ void setupTimeWaitClose(const TestAddress* listener, .events = want_events, }; ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - }); + } - passive_closefd.reset(); - t.Join(); - active_closefd.reset(); // This sleep is needed to reduce flake to ensure that the passive-close // ensures the state transitions to CLOSE from LAST_ACK. absl::SleepFor(absl::Seconds(1)); diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 8f7ccc868..029f1e872 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -454,23 +454,15 @@ TEST_P(IPUnboundSocketTest, SetReuseAddr) { INSTANTIATE_TEST_SUITE_P( IPUnboundSockets, IPUnboundSocketTest, - ::testing::ValuesIn(VecCat<SocketKind>(VecCat<SocketKind>( + ::testing::ValuesIn(VecCat<SocketKind>( ApplyVec<SocketKind>(IPv4UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector<int>{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv6UDPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_DGRAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector<int>{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv4TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{0, - SOCK_NONBLOCK})), + std::vector{0, SOCK_NONBLOCK}), ApplyVec<SocketKind>(IPv6TCPUnboundSocket, - AllBitwiseCombinations(List<int>{SOCK_STREAM}, - List<int>{ - 0, SOCK_NONBLOCK})))))); + std::vector{0, SOCK_NONBLOCK})))); } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound_netlink.cc b/test/syscalls/linux/socket_ip_unbound_netlink.cc new file mode 100644 index 000000000..6036bfcaf --- /dev/null +++ b/test/syscalls/linux/socket_ip_unbound_netlink.cc @@ -0,0 +1,104 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of IP sockets. +using IPv6UnboundSocketTest = SimpleSocketTest; + +TEST_P(IPv6UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved + // across save/restore. + DisableSave ds; + + // Delete the loopback address from the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/128, &in6addr_loopback, + sizeof(in6addr_loopback))); + Cleanup defer_addr_removal = + Cleanup([loopback_link = std::move(loopback_link)] { + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/128, &in6addr_loopback, + sizeof(in6addr_loopback))); + }); + + TestAddress addr = V6Loopback(); + reinterpret_cast<sockaddr_in6*>(&addr.addr)->sin6_port = 65535; + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv6UnboundSocketTest, + ::testing::ValuesIn(std::vector<SocketKind>{ + IPv6UDPUnboundSocket(0), + IPv6TCPUnboundSocket(0)})); + +using IPv4UnboundSocketTest = SimpleSocketTest; + +TEST_P(IPv4UnboundSocketTest, ConnectToBadLocalAddress_NoRandomSave) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // TODO(gvisor.dev/issue/4595): Addresses on net devices are not saved + // across save/restore. + DisableSave ds; + + // Delete the loopback address from the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr laddr; + laddr.s_addr = htonl(INADDR_LOOPBACK); + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/8, &laddr, sizeof(laddr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(laddr)] { + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/8, &addr, sizeof(addr))); + }); + TestAddress addr = V4Loopback(); + reinterpret_cast<sockaddr_in*>(&addr.addr)->sin_port = 65535; + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(connect(sock->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +INSTANTIATE_TEST_SUITE_P(IPUnboundSockets, IPv4UnboundSocketTest, + ::testing::ValuesIn(std::vector<SocketKind>{ + IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0)})); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index e83f0d81f..ee3c08770 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -536,7 +536,7 @@ TEST(NetlinkRouteTest, AddAndRemoveAddr) { // Second delete should fail, as address no longer exists. EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EINVAL, ::testing::_)); + PosixErrorIs(EADDRNOTAVAIL, ::testing::_)); }); // Replace an existing address should succeed. diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index a1d2b9b11..c2369db54 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -26,9 +26,11 @@ #include "absl/time/clock.h" #include "absl/time/time.h" #include "test/util/file_descriptor.h" +#include "test/util/signal_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/timer_util.h" namespace gvisor { namespace testing { @@ -772,6 +774,59 @@ TEST(SpliceTest, FromPipeToDevZero) { SyscallSucceedsWithValue(0)); } +static volatile int signaled = 0; +void SigUsr1Handler(int sig, siginfo_t* info, void* context) { signaled = 1; } + +TEST(SpliceTest, ToPipeWithSmallCapacityDoesNotSpin_NoRandomSave) { + // Writes to a pipe that are less than PIPE_BUF must be atomic. This test + // creates a pipe with only 128 bytes of capacity (< PIPE_BUF) and checks that + // splicing to the pipe does not spin. See b/170743336. + + // Create a file with one page of data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), absl::string_view(buf.data(), buf.size()), + TempPath::kDefaultFileMode)); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + // Create a pipe with size 4096, and fill all but 128 bytes of it. + int p[2]; + ASSERT_THAT(pipe(p), SyscallSucceeds()); + ASSERT_THAT(fcntl(p[1], F_SETPIPE_SZ, kPageSize), SyscallSucceeds()); + const int kWriteSize = kPageSize - 128; + std::vector<char> writeBuf(kWriteSize); + RandomizeBuffer(writeBuf.data(), writeBuf.size()); + ASSERT_THAT(write(p[1], writeBuf.data(), writeBuf.size()), + SyscallSucceedsWithValue(kWriteSize)); + + // Set up signal handler. + struct sigaction sa = {}; + sa.sa_sigaction = SigUsr1Handler; + sa.sa_flags = SA_SIGINFO; + const auto cleanup_sigact = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGUSR1, sa)); + + // Send SIGUSR1 to this thread in 1 second. + struct sigevent sev = {}; + sev.sigev_notify = SIGEV_THREAD_ID; + sev.sigev_signo = SIGUSR1; + sev.sigev_notify_thread_id = gettid(); + auto timer = ASSERT_NO_ERRNO_AND_VALUE(TimerCreate(CLOCK_MONOTONIC, sev)); + struct itimerspec its = {}; + its.it_value = absl::ToTimespec(absl::Seconds(1)); + DisableSave ds; // Asserting an EINTR. + ASSERT_NO_ERRNO(timer.Set(0, its)); + + // Now splice the file to the pipe. This should block, but not spin, and + // should return EINTR because it is interrupted by the signal. + EXPECT_THAT(splice(fd.get(), nullptr, p[1], nullptr, kPageSize, 0), + SyscallFailsWithErrno(EINTR)); + + // Alarm should have been handled. + EXPECT_EQ(signaled, 1); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index cac94d9e1..93a98adb1 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -322,11 +322,6 @@ TEST(IntervalTimerTest, PeriodicGroupDirectedSignal) { EXPECT_GE(counted_signals.load(), kCycles); } -// From Linux's include/uapi/asm-generic/siginfo.h. -#ifndef sigev_notify_thread_id -#define sigev_notify_thread_id _sigev_un._tid -#endif - TEST(IntervalTimerTest, PeriodicThreadDirectedSignal) { constexpr int kSigno = SIGUSR1; constexpr int kSigvalue = 42; diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index bc5bd9218..d65275fd3 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -1887,6 +1887,22 @@ TEST_P(UdpSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } +TEST_P(UdpSocketTest, SendToZeroPort) { + char buf[8]; + struct sockaddr_storage addr = InetLoopbackAddr(); + + // Sending to an invalid port should fail. + SetPort(&addr, 0); + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallFailsWithErrno(EINVAL)); + + SetPort(&addr, 1234); + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, + reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)), + SyscallSucceedsWithValue(sizeof(buf))); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/util/timer_util.h b/test/util/timer_util.h index 926e6632f..e389108ef 100644 --- a/test/util/timer_util.h +++ b/test/util/timer_util.h @@ -33,6 +33,11 @@ namespace gvisor { namespace testing { +// From Linux's include/uapi/asm-generic/siginfo.h. +#ifndef sigev_notify_thread_id +#define sigev_notify_thread_id _sigev_un._tid +#endif + // Returns the current time. absl::Time Now(clockid_t id); diff --git a/tools/bazel.mk b/tools/bazel.mk index 88431ce66..3a7de427f 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -26,13 +26,13 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ BUILD_ROOTS := bazel-bin/ bazel-out/ # Bazel container configuration (see below). -USER ?= gvisor -HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) +USER := $(shell whoami) +HASH := $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8) BUILDER_BASE := gvisor.dev/images/default BUILDER_IMAGE := gvisor.dev/images/builder -BUILDER_NAME ?= gvisor-builder-$(HASH) -DOCKER_NAME ?= gvisor-bazel-$(HASH) -DOCKER_PRIVILEGED ?= --privileged +BUILDER_NAME := gvisor-builder-$(HASH) +DOCKER_NAME := gvisor-bazel-$(HASH) +DOCKER_PRIVILEGED := --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock @@ -59,6 +59,25 @@ ifeq (true,$(shell [[ -t 0 ]] && echo true)) FULL_DOCKER_EXEC_OPTIONS += --tty endif +# Add basic UID/GID options. +# +# Note that USERADD_DOCKER and GROUPADD_DOCKER are both defined as "deferred" +# variables in Make terminology, that is they will be expanded at time of use +# and may include other variables, including those defined below. +# +# NOTE: we pass -l to useradd below because otherwise you can hit a bug +# best described here: +# https://github.com/moby/moby/issues/5419#issuecomment-193876183 +# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out +# out of disk space. +ifneq ($(UID),0) +USERADD_DOCKER += useradd -l --uid $(UID) --non-unique --no-create-home \ + --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && +endif +ifneq ($(GID),0) +GROUPADD_DOCKER += groupadd --gid $(GID) --non-unique $(USER) && +endif + # Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" @@ -91,19 +110,12 @@ ifneq (,$(BAZEL_CONFIG)) OPTIONS += --config=$(BAZEL_CONFIG) endif -# NOTE: we pass -l to useradd below because otherwise you can hit a bug -# best described here: -# https://github.com/moby/moby/issues/5419#issuecomment-193876183 -# TLDR; trying to add to /var/log/lastlog (sparse file) runs the machine out -# out of disk space. bazel-image: load-default @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \ $(BUILDER_BASE) \ - sh -c "groupadd --gid $(GID) --non-unique $(USER) && \ - $(GROUPADD_DOCKER) \ - useradd -l --uid $(UID) --non-unique --no-create-home \ - --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \ + sh -c "$(GROUPADD_DOCKER) \ + $(USERADD_DOCKER) \ if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi" docker commit $(BUILDER_NAME) $(BUILDER_IMAGE) @docker rm -f $(BUILDER_NAME) diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 34b270cc0..544af3876 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -105,7 +105,7 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { } // NewBenchmark initializes a new benchmark. -func NewBenchmark(name string, iters int, official bool) *Benchmark { +func NewBenchmark(name string, iters int) *Benchmark { return &Benchmark{ Name: name, Metric: make([]*Metric, 0), @@ -113,12 +113,13 @@ func NewBenchmark(name string, iters int, official bool) *Benchmark { } // NewSuite initializes a new Suite. -func NewSuite(name string) *Suite { +func NewSuite(name string, official bool) *Suite { return &Suite{ Name: name, Timestamp: time.Now().UTC(), Benchmarks: make([]*Benchmark, 0), Conditions: make([]*Condition, 0), + Official: official, } } diff --git a/tools/go_branch.sh b/tools/go_branch.sh index 100c4ae41..71d036b12 100755 --- a/tools/go_branch.sh +++ b/tools/go_branch.sh @@ -61,14 +61,6 @@ go_branch=$( \ ) readonly go_branch -declare commit -commit=$(git rev-parse HEAD) -readonly commit -if [[ -n "$(git branch --contains="${commit}" go)" ]]; then - # The go branch already has the commit. - exit 0 -fi - # Clone the current repository to the temporary directory, and check out the # current go_branch directory. We move to the new repository for convenience. declare repo_orig @@ -117,7 +109,7 @@ EOF # There are a few solitary files that can get left behind due to the way bazel # constructs the gopath target. Note that we don't find all Go files here # because they may correspond to unused templates, etc. -declare -ar binaries=( "runsc" "shim/v1" "shim/v2" ) +declare -ar binaries=( "runsc" "shim/v1" "shim/v2" "webhook" ) for target in "${binaries[@]}"; do mkdir -p "${target}" cp "${repo_orig}/${target}"/*.go "${target}/" @@ -130,7 +122,11 @@ find . -type f -exec chmod 0644 {} \; find . -type d -exec chmod 0755 {} \; # Update the current working set and commit. -git add . && git commit -m "Merge ${head} (automated)" +# If the current working commit has already been committed to the remote go +# branch, then we have nothing to commit here. So allow empty commit. This can +# occur when this script is run parallely (via pull_request and push events) +# and the push workflow finishes before the pull_request workflow can run this. +git add . && git commit --allow-empty -m "Merge ${head} (automated)" # Push the branch back to the original repository. git remote add orig "${repo_orig}" && git push -f orig go:go diff --git a/tools/go_generics/imports.go b/tools/go_generics/imports.go index 90d3aa1e0..370650e46 100644 --- a/tools/go_generics/imports.go +++ b/tools/go_generics/imports.go @@ -48,7 +48,7 @@ func updateImportIdent(orig string, imports mapValue, id *ast.Ident, used map[st // Create a new entry in the used map. path := imports[importName] if path == "" { - return fmt.Errorf("Unknown path to package '%s', used in '%s'", importName, orig) + return fmt.Errorf("unknown path to package '%s', used in '%s'", importName, orig) } m = &importedPackage{ @@ -72,7 +72,7 @@ func convertExpression(s string, imports mapValue, used map[string]*importedPack // Parse the expression in the input string. expr, err := parser.ParseExpr(s) if err != nil { - return "", fmt.Errorf("Unable to parse \"%s\": %v", s, err) + return "", fmt.Errorf("unable to parse \"%s\": %v", s, err) } // Go through the AST and update references. diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go index 7f62b0a2b..df14ae98e 100644 --- a/tools/go_marshal/test/escape/escape.go +++ b/tools/go_marshal/test/escape/escape.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package escape contains test cases for escape analysis. package escape import ( diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go index d9e9f341b..e7e3ed74a 100644 --- a/tools/go_marshal/test/test.go +++ b/tools/go_marshal/test/test.go @@ -161,7 +161,7 @@ type TestArray [sizeA]int32 // +marshal type TestArray2 [sizeA * sizeB]int32 -// TestArray2 is a newtype on an array with a simple arithmetic expression of +// TestArray3 is a newtype on an array with a simple arithmetic expression of // mixed constants and literals for the array length. // // +marshal diff --git a/tools/parsers/BUILD b/tools/parsers/BUILD index dab954e25..6932bba9a 100644 --- a/tools/parsers/BUILD +++ b/tools/parsers/BUILD @@ -31,8 +31,12 @@ go_library( go_binary( name = "parser", testonly = 1, - srcs = ["parser_main.go"], + srcs = [ + "parser_main.go", + "version.go", + ], nogo = False, + x_defs = {"main.version": "{STABLE_VERSION}"}, deps = [ ":parsers", "//runsc/flag", diff --git a/tools/parsers/go_parser.go b/tools/parsers/go_parser.go index df4875e6a..57e538149 100644 --- a/tools/parsers/go_parser.go +++ b/tools/parsers/go_parser.go @@ -30,10 +30,10 @@ import ( // ParseOutput expects golang benchmark output and returns a struct formatted // for BigQuery. func ParseOutput(output string, name string, official bool) (*bigquery.Suite, error) { - suite := bigquery.NewSuite(name) + suite := bigquery.NewSuite(name, official) lines := strings.Split(output, "\n") for _, line := range lines { - bm, err := parseLine(line, official) + bm, err := parseLine(line) if err != nil { return nil, fmt.Errorf("failed to parse line '%s': %v", line, err) } @@ -60,7 +60,7 @@ func ParseOutput(output string, name string, official bool) (*bigquery.Suite, er // {Name: requests_per_second, Unit: QPS, Sample: 140 } // } //} -func parseLine(line string, official bool) (*bigquery.Benchmark, error) { +func parseLine(line string) (*bigquery.Benchmark, error) { fields := strings.Fields(line) // Check if this line is a Benchmark line. Otherwise ignore the line. @@ -78,7 +78,7 @@ func parseLine(line string, official bool) (*bigquery.Benchmark, error) { return nil, fmt.Errorf("parse name/params: %v", err) } - bm := bigquery.NewBenchmark(name, iters, official) + bm := bigquery.NewBenchmark(name, iters) for _, p := range params { bm.AddCondition(p.Name, p.Value) } diff --git a/tools/parsers/go_parser_test.go b/tools/parsers/go_parser_test.go index 0aa1152a2..f0737d46b 100644 --- a/tools/parsers/go_parser_test.go +++ b/tools/parsers/go_parser_test.go @@ -94,7 +94,7 @@ func TestParseLine(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := parseLine(tc.data, false) + got, err := parseLine(tc.data) if err != nil { t.Fatalf("parseLine failed with: %v", err) } diff --git a/tools/parsers/parser_main.go b/tools/parsers/parser_main.go index 6c6182464..7cce69e03 100644 --- a/tools/parsers/parser_main.go +++ b/tools/parsers/parser_main.go @@ -49,12 +49,11 @@ var ( parseCmd = flag.NewFlagSet(parseString, flag.ContinueOnError) file = parseCmd.String("file", "", "file to parse for benchmarks") name = parseCmd.String("suite_name", "", "name of the benchmark suite") - clNumber = parseCmd.String("cl", "", "changelist number of this run") - gitCommit = parseCmd.String("git_commit", "", "git commit sha for this run") parseProject = parseCmd.String("project", "", "GCP project to send benchmarks.") parseDataset = parseCmd.String("dataset", "", "dataset to send benchmarks data.") parseTable = parseCmd.String("table", "", "table to send benchmarks data.") official = parseCmd.Bool("official", false, "mark input data as official.") + runtime = parseCmd.String("runtime", "", "runtime used to run the benchmark") ) // initBenchmarks initializes a dataset/table in a BigQuery project. @@ -67,23 +66,29 @@ func initBenchmarks(ctx context.Context) error { func parseBenchmarks(ctx context.Context) error { data, err := ioutil.ReadFile(*file) if err != nil { - return fmt.Errorf("failed to read file: %v", err) + return fmt.Errorf("failed to read file %s: %v", *file, err) } suite, err := parsers.ParseOutput(string(data), *name, *official) if err != nil { return fmt.Errorf("failed parse data: %v", err) } + if len(suite.Benchmarks) < 1 { + fmt.Fprintf(os.Stderr, "Failed to find benchmarks for file: %s", *file) + return nil + } + extraConditions := []*bq.Condition{ { - Name: "change_list", - Value: *clNumber, + Name: "runtime", + Value: *runtime, }, { - Name: "commit", - Value: *gitCommit, + Name: "version", + Value: version, }, } + suite.Official = *official suite.Conditions = append(suite.Conditions, extraConditions...) return bq.SendBenchmarks(ctx, suite, *parseProject, *parseDataset, *parseTable, nil) } @@ -94,26 +99,27 @@ func main() { // the "init" command case len(os.Args) >= 2 && os.Args[1] == initString: if err := initCmd.Parse(os.Args[2:]); err != nil { - fmt.Fprintf(os.Stderr, "failed parse flags: %v", err) + fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err) os.Exit(1) } if err := initBenchmarks(ctx); err != nil { - failure := "failed to initialize project: %s dataset: %s table: %s: %v" + failure := "failed to initialize project: %s dataset: %s table: %s: %v\n" fmt.Fprintf(os.Stderr, failure, *parseProject, *parseDataset, *parseTable, err) os.Exit(1) } // the "parse" command. case len(os.Args) >= 2 && os.Args[1] == parseString: if err := parseCmd.Parse(os.Args[2:]); err != nil { - fmt.Fprintf(os.Stderr, "failed parse flags: %v", err) + fmt.Fprintf(os.Stderr, "failed parse flags: %v\n", err) os.Exit(1) } if err := parseBenchmarks(ctx); err != nil { - fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v", err) + fmt.Fprintf(os.Stderr, "failed parse benchmarks: %v\n", err) os.Exit(1) } default: printUsage() + os.Exit(1) } } diff --git a/tools/parsers/version.go b/tools/parsers/version.go new file mode 100644 index 000000000..ab9194b9d --- /dev/null +++ b/tools/parsers/version.go @@ -0,0 +1,18 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +// version is set during linking. +var version = "VERSION_MISSING" diff --git a/webhook/BUILD b/webhook/BUILD new file mode 100644 index 000000000..33c585504 --- /dev/null +++ b/webhook/BUILD @@ -0,0 +1,28 @@ +load("//images:defs.bzl", "docker_image") +load("//tools:defs.bzl", "go_binary", "pkg_tar") + +package(licenses = ["notice"]) + +docker_image( + name = "webhook_image", + data = ":files", + statements = ['ENTRYPOINT ["/webhook"]'], +) + +# files is the full file system of the webhook container. It is simply: +# / +# └─ webhook +pkg_tar( + name = "files", + srcs = [":webhook"], + extension = "tgz", + strip_prefix = "/third_party/gvisor/webhook", +) + +go_binary( + name = "webhook", + srcs = ["main.go"], + pure = "on", + static = "on", + deps = ["//webhook/pkg/cli"], +) diff --git a/webhook/main.go b/webhook/main.go new file mode 100644 index 000000000..220016543 --- /dev/null +++ b/webhook/main.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary main serves a mutating Kubernetes webhook. +package main + +import ( + "gvisor.dev/gvisor/webhook/pkg/cli" +) + +func main() { + cli.Main() +} diff --git a/webhook/pkg/cli/BUILD b/webhook/pkg/cli/BUILD new file mode 100644 index 000000000..ac093c556 --- /dev/null +++ b/webhook/pkg/cli/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "cli", + srcs = ["cli.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "//webhook/pkg/injector", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_apimachinery//pkg/util/net:go_default_library", + "@io_k8s_client_go//kubernetes:go_default_library", + "@io_k8s_client_go//rest:go_default_library", + ], +) diff --git a/webhook/pkg/cli/cli.go b/webhook/pkg/cli/cli.go new file mode 100644 index 000000000..a07d341a2 --- /dev/null +++ b/webhook/pkg/cli/cli.go @@ -0,0 +1,115 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli provides a CLI interface for a mutating Kubernetes webhook. +package cli + +import ( + "flag" + "fmt" + "net" + "net/http" + "os" + "strconv" + "strings" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/webhook/pkg/injector" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8snet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +var ( + address = flag.String("address", "", "The ip address the admission webhook serves on. If unspecified, a public address is selected automatically.") + port = flag.Int("port", 0, "The port the admission webhook serves on.") + podLabels = flag.String("pod-namespace-labels", "", "A comma-separated namespace label selector, the admission webhook will only take effect on pods in selected namespaces, e.g. `label1,label2`.") +) + +// Main runs the webhook. +func Main() { + flag.Parse() + + if err := run(); err != nil { + log.Warningf("%v", err) + os.Exit(1) + } +} + +func run() error { + log.Infof("Starting %s\n", injector.Name) + + // Create client config. + cfg, err := rest.InClusterConfig() + if err != nil { + return fmt.Errorf("create in cluster config: %w", err) + } + + // Create clientset. + clientset, err := kubernetes.NewForConfig(cfg) + if err != nil { + return fmt.Errorf("create kubernetes client: %w", err) + } + + if err := injector.CreateConfiguration(clientset, parsePodLabels()); err != nil { + return fmt.Errorf("create webhook configuration: %w", err) + } + + if err := startWebhookHTTPS(clientset); err != nil { + return fmt.Errorf("start webhook https server: %w", err) + } + + return nil +} + +func parsePodLabels() *metav1.LabelSelector { + rv := &metav1.LabelSelector{} + for _, s := range strings.Split(*podLabels, ",") { + req := metav1.LabelSelectorRequirement{ + Key: strings.TrimSpace(s), + Operator: "Exists", + } + rv.MatchExpressions = append(rv.MatchExpressions, req) + } + return rv +} + +func startWebhookHTTPS(clientset kubernetes.Interface) error { + log.Infof("Starting HTTPS handler") + defer log.Infof("Stopping HTTPS handler") + + if *address == "" { + ip, err := k8snet.ChooseHostInterface() + if err != nil { + return fmt.Errorf("select ip address: %w", err) + } + *address = ip.String() + } + mux := http.NewServeMux() + mux.Handle("/", http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + injector.Admit(w, r) + })) + server := &http.Server{ + // Listen on all addresses. + Addr: net.JoinHostPort(*address, strconv.Itoa(*port)), + TLSConfig: injector.GetTLSConfig(), + Handler: mux, + } + if err := server.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + return fmt.Errorf("start HTTPS handler: %w", err) + } + return nil +} diff --git a/webhook/pkg/injector/BUILD b/webhook/pkg/injector/BUILD new file mode 100644 index 000000000..d296981be --- /dev/null +++ b/webhook/pkg/injector/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "injector", + srcs = [ + "certs.go", + "webhook.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/log", + "@com_github_mattbaird_jsonpatch//:go_default_library", + "@io_k8s_api//admission/v1beta1:go_default_library", + "@io_k8s_api//admissionregistration/v1beta1:go_default_library", + "@io_k8s_api//core/v1:go_default_library", + "@io_k8s_apimachinery//pkg/api/errors:go_default_library", + "@io_k8s_apimachinery//pkg/apis/meta/v1:go_default_library", + "@io_k8s_client_go//kubernetes:go_default_library", + ], +) + +genrule( + name = "certs", + srcs = [":gencerts"], + outs = ["certs.go"], + cmd = "$$(cut -d ' ' -f 1 <<< \"$(locations :gencerts)\") $@", +) + +sh_binary( + name = "gencerts", + srcs = ["gencerts.sh"], +) diff --git a/webhook/pkg/injector/gencerts.sh b/webhook/pkg/injector/gencerts.sh new file mode 100755 index 000000000..f7fda4b63 --- /dev/null +++ b/webhook/pkg/injector/gencerts.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Generates the a CA cert, a server key, and a server cert signed by the CA. +# reference: +# https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/admission/plugin/webhook/testcerts/gencerts.sh +set -euo pipefail + +# Do all the work in TMPDIR, then copy out generated code and delete TMPDIR. +declare -r OUTDIR="$(readlink -e .)" +declare -r TMPDIR="$(mktemp -d)" +cd "${TMPDIR}" +function cleanup() { + cd "${OUTDIR}" + rm -rf "${TMPDIR}" +} +trap cleanup EXIT + +declare -r CN_BASE="e2e" +declare -r CN="gvisor-injection-admission-webhook.e2e.svc" + +cat > server.conf << EOF +[req] +req_extensions = v3_req +distinguished_name = req_distinguished_name +[req_distinguished_name] +[ v3_req ] +basicConstraints = CA:FALSE +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +extendedKeyUsage = clientAuth, serverAuth +EOF + +declare -r OUTFILE="${TMPDIR}/certs.go" + +# We depend on OpenSSL being present. + +# Create a certificate authority. +openssl genrsa -out caKey.pem 2048 +openssl req -x509 -new -nodes -key caKey.pem -days 100000 -out caCert.pem -subj "/CN=${CN_BASE}_ca" -config server.conf + +# Create a server certificate. +openssl genrsa -out serverKey.pem 2048 +# Note the CN is the DNS name of the service of the webhook. +openssl req -new -key serverKey.pem -out server.csr -subj "/CN=${CN}" -config server.conf +openssl x509 -req -in server.csr -CA caCert.pem -CAkey caKey.pem -CAcreateserial -out serverCert.pem -days 100000 -extensions v3_req -extfile server.conf + +echo "package injector" > "${OUTFILE}" +echo "" >> "${OUTFILE}" +echo "// This file was generated using openssl by the gencerts.sh script." >> "${OUTFILE}" +for file in caKey caCert serverKey serverCert; do + DATA=$(cat "${file}.pem") + echo "" >> "${OUTFILE}" + echo "var $file = []byte(\`$DATA\`)" >> "${OUTFILE}" +done + +# Copy generated code into the output directory. +cp "${OUTFILE}" "${OUTDIR}/$1" diff --git a/webhook/pkg/injector/webhook.go b/webhook/pkg/injector/webhook.go new file mode 100644 index 000000000..614b5add7 --- /dev/null +++ b/webhook/pkg/injector/webhook.go @@ -0,0 +1,211 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package injector handles mutating webhook operations. +package injector + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "os" + + "github.com/mattbaird/jsonpatch" + "gvisor.dev/gvisor/pkg/log" + admv1beta1 "k8s.io/api/admission/v1beta1" + admregv1beta1 "k8s.io/api/admissionregistration/v1beta1" + v1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kubeclientset "k8s.io/client-go/kubernetes" +) + +const ( + // Name is the name of the admission webhook service. The admission + // webhook must be exposed in the following service; this is mainly for + // the server certificate. + Name = "gvisor-injection-admission-webhook" + + // serviceNamespace is the namespace of the admission webhook service. + serviceNamespace = "e2e" + + fullName = Name + "." + serviceNamespace + ".svc" +) + +// CreateConfiguration creates MutatingWebhookConfiguration and registers the +// webhook admission controller with the kube-apiserver. The webhook will only +// take effect on pods in the namespaces selected by `podNsSelector`. If `podNsSelector` +// is empty, the webhook will take effect on all pods. +func CreateConfiguration(clientset kubeclientset.Interface, selector *metav1.LabelSelector) error { + fail := admregv1beta1.Fail + + config := &admregv1beta1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: Name, + }, + Webhooks: []admregv1beta1.MutatingWebhook{ + { + Name: fullName, + ClientConfig: admregv1beta1.WebhookClientConfig{ + Service: &admregv1beta1.ServiceReference{ + Name: Name, + Namespace: serviceNamespace, + }, + CABundle: caCert, + }, + Rules: []admregv1beta1.RuleWithOperations{ + { + Operations: []admregv1beta1.OperationType{ + admregv1beta1.Create, + }, + Rule: admregv1beta1.Rule{ + APIGroups: []string{"*"}, + APIVersions: []string{"*"}, + Resources: []string{"pods"}, + }, + }, + }, + FailurePolicy: &fail, + NamespaceSelector: selector, + }, + }, + } + log.Infof("Creating MutatingWebhookConfiguration %q", config.Name) + if _, err := clientset.AdmissionregistrationV1beta1().MutatingWebhookConfigurations().Create(config); err != nil { + if !apierrors.IsAlreadyExists(err) { + return fmt.Errorf("failed to create MutatingWebhookConfiguration %q: %s", config.Name, err) + } + log.Infof("MutatingWebhookConfiguration %q already exists; use the existing one", config.Name) + } + return nil +} + +// GetTLSConfig retrieves the CA cert that signed the cert used by the webhook. +func GetTLSConfig() *tls.Config { + serverCert, err := tls.X509KeyPair(serverCert, serverKey) + if err != nil { + log.Warningf("Failed to generate X509 key pair: %v", err) + os.Exit(1) + } + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } +} + +// Admit performs admission checks and mutations on Pods. +func Admit(writer http.ResponseWriter, req *http.Request) { + review := &admv1beta1.AdmissionReview{} + if err := json.NewDecoder(req.Body).Decode(review); err != nil { + log.Infof("Failed with error (%v) to decode Admit request: %+v", err, *req) + writer.WriteHeader(http.StatusBadRequest) + return + } + + log.Debugf("admitPod: %+v", review) + var err error + review.Response, err = admitPod(review.Request) + if err != nil { + log.Warningf("admitPod failed: %v", err) + review.Response = &admv1beta1.AdmissionResponse{ + Result: &metav1.Status{ + Reason: metav1.StatusReasonInvalid, + Message: err.Error(), + }, + } + sendResponse(writer, review) + return + } + + log.Debugf("Processed admission review: %+v", review) + sendResponse(writer, review) +} + +func sendResponse(writer http.ResponseWriter, response interface{}) { + b, err := json.Marshal(response) + if err != nil { + log.Warningf("Failed with error (%v) to marshal response: %+v", err, response) + writer.WriteHeader(http.StatusInternalServerError) + return + } + + writer.WriteHeader(http.StatusOK) + writer.Write(b) +} + +func admitPod(req *admv1beta1.AdmissionRequest) (*admv1beta1.AdmissionResponse, error) { + // Verify that the request is indeed a Pod. + resource := metav1.GroupVersionResource{Group: "", Version: "v1", Resource: "pods"} + if req.Resource != resource { + return nil, fmt.Errorf("unexpected resource %+v in pod admission", req.Resource) + } + + // Decode the request into a Pod. + pod := &v1.Pod{} + if err := json.Unmarshal(req.Object.Raw, pod); err != nil { + return nil, fmt.Errorf("failed to decode pod object %s/%s", req.Namespace, req.Name) + } + + // Copy first to change it. + podCopy := pod.DeepCopy() + updatePod(podCopy) + patch, err := createPatch(req.Object.Raw, podCopy) + if err != nil { + return nil, fmt.Errorf("failed to create patch for pod %s/%s (generatedName: %s)", pod.Namespace, pod.Name, pod.GenerateName) + } + + log.Debugf("Patched pod %s/%s (generateName: %s): %+v", pod.Namespace, pod.Name, pod.GenerateName, podCopy) + patchType := admv1beta1.PatchTypeJSONPatch + return &admv1beta1.AdmissionResponse{ + Allowed: true, + Patch: patch, + PatchType: &patchType, + }, nil +} + +func updatePod(pod *v1.Pod) { + gvisor := "gvisor" + pod.Spec.RuntimeClassName = &gvisor + + // We don't run SELinux test for gvisor. + // If SELinuxOptions are specified, this is usually for volume test to pass + // on SELinux. This can be safely ignored. + if pod.Spec.SecurityContext != nil && pod.Spec.SecurityContext.SELinuxOptions != nil { + pod.Spec.SecurityContext.SELinuxOptions = nil + } + for i := range pod.Spec.Containers { + c := &pod.Spec.Containers[i] + if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil { + c.SecurityContext.SELinuxOptions = nil + } + } + for i := range pod.Spec.InitContainers { + c := &pod.Spec.InitContainers[i] + if c.SecurityContext != nil && c.SecurityContext.SELinuxOptions != nil { + c.SecurityContext.SELinuxOptions = nil + } + } +} + +func createPatch(old []byte, newObj interface{}) ([]byte, error) { + new, err := json.Marshal(newObj) + if err != nil { + return nil, err + } + patch, err := jsonpatch.CreatePatch(old, new) + if err != nil { + return nil, err + } + return json.Marshal(patch) +} |