diff options
138 files changed, 4426 insertions, 1892 deletions
@@ -28,20 +28,12 @@ build:remote --bes_results_url="https://source.cloud.google.com/results/invocati build:remote --bes_timeout=600s build:remote --project_id=gvisor-rbe build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance -build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com -build:remote3 --project_id=gvisor-rbe -build:remote3 --bes_backend=buildeventservice.googleapis.com -build:remote3 --bes_results_url="https://source.cloud.google.com/results/invocations" -build:remote3 --bes_timeout=600s -build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance # Enable authentication. This will pick up application default credentials by # default. You can use --google_credentials=some_file.json to use a service # account credential instead. build:remote --google_default_credentials=true build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" -build:remote3 --google_default_credentials=true -build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools" # Add a custom platform and toolchain that builds in a privileged docker # container, which is required by our syscall tests. @@ -50,31 +42,5 @@ build:remote --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-defa build:remote --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604 build:remote --crosstool_top=@rbe_default//cc:toolchain -build:remote --jobs=100 +build:remote --jobs=300 build:remote --remote_timeout=3600 -build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3 -build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3 -build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 -build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3 -build:remote3 --crosstool_top=@rbe_default//cc:toolchain -build:remote3 --jobs=100 -build:remote3 --remote_timeout=3600 - -# Set flags for uploading to BES in order to view results in the Bazel Build -# Results UI. -build:results --bes_backend="buildeventservice.googleapis.com" -build:results --bes_timeout=60s -build:results --tls_enabled - -# Output BES results url -build:results --bes_results_url="https://source.cloud.google.com/results/invocations/" - -# Set flags for uploading to BES without Remote Build Execution. -build:results-local --bes_backend="buildeventservice.googleapis.com" -build:results-local --bes_timeout=60s -build:results-local --tls_enabled=true -build:results-local --auth_enabled=true -build:results-local --spawn_strategy=local -build:results-local --remote_cache=remotebuildexecution.googleapis.com -build:results-local --remote_timeout=3600 -build:results-local --bes_results_url="https://source.cloud.google.com/results/invocations/" diff --git a/.gitignore b/.gitignore index 13babef4d..a56f6ebcd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ # Generated bazel symlinks. -/bazel-* +/bazel-*
\ No newline at end of file @@ -75,12 +75,19 @@ go_path( name = "gopath", mode = "link", deps = [ - # Main binary. - "//runsc", - "//shim/v1:gvisor-containerd-shim", - "//shim/v2:containerd-shim-runsc-v1", + # Main binaries. + # + # For reasons related to reproducibility of the generated + # files, in order to ensure that :gopath produces only a + # a single "pure" version of all files, we can only depend + # on go_library targets here, and not go_binary. Thus the + # binaries have been factored into a cli package, which is + # a good practice in any case. + "//runsc/cli", + "//shim/v1/cli", + "//shim/v2/cli", - # Packages that are not dependencies of //runsc. + # Packages that are not dependencies of the above. "//pkg/sentry/kernel/memevent", "//pkg/tcpip/adapters/gonet", "//pkg/tcpip/link/channel", diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 89180eb3f..c53df7d25 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -70,10 +70,8 @@ Rules: * `@org_golang_x_sys//unix:go_default_library` (Go import `golang.org/x/sys/unix`). * Generated Go protobuf packages. - * `@com_github_golang_protobuf//proto:go_default_library` (Go import - `github.com/golang/protobuf/proto`). - * `@com_github_golang_protobuf//ptypes:go_default_library` (Go import - `github.com/golang/protobuf/ptypes`). + * `@org_golang_google_protobuf//proto:go_default_library` (Go import + `google.golang.org/protobuf`). * `runsc` may only depend on the following packages: @@ -94,9 +94,9 @@ endef rebuild-...: ## Rebuild the given image. Also may use 'rebuild-all-images'. $(eval $(call images,rebuild)) push-...: ## Push the given image. Also may use 'push-all-images'. -$(eval $(call images,pull)) -pull-...: ## Pull the given image. Also may use 'pull-all-images'. $(eval $(call images,push)) +pull-...: ## Pull the given image. Also may use 'pull-all-images'. +$(eval $(call images,pull)) load-...: ## Load (pull or rebuild) the given image. Also may use 'load-all-images'. $(eval $(call images,load)) list-images: ## List all available images. @@ -130,8 +130,7 @@ unit-tests: ## Local package unit tests in pkg/..., runsc/, tools/.., etc. .PHONY: unit-tests tests: ## Runs all unit tests and syscall tests. -tests: unit-tests - @$(call submake,test TARGETS="test/syscalls/...") +tests: unit-tests syscall-tests .PHONY: tests integration-tests: ## Run all standard integration tests. @@ -147,15 +146,14 @@ network-tests: iptables-tests packetdrill-tests packetimpact-tests INTEGRATION_TARGETS := //test/image:image_test //test/e2e:integration_test syscall-%-tests: - @$(call submake,test OPTIONS="--test_tag_filters runsc_$* test/syscalls/...") + @$(call submake,test OPTIONS="--test_tag_filters runsc_$*" TARGETS="test/syscalls/...") syscall-native-tests: - @$(call submake,test OPTIONS="--test_tag_filters native test/syscalls/...") + @$(call submake,test OPTIONS="--test_tag_filters native" TARGETS="test/syscalls/...") .PHONY: syscall-native-tests syscall-tests: ## Run all system call tests. -syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests -.PHONY: syscall-tests + @$(call submake,test TARGETS="test/syscalls/...") %-runtime-tests: load-runtimes_% @$(call submake,install-test-runtime) @@ -258,7 +256,7 @@ WEBSITE_PROJECT := gvisordev WEBSITE_REGION := us-central1 website-build: load-jekyll ## Build the site image locally. - @$(call submake,run TARGETS="//website:website") + @$(call submake,run TARGETS="//website:website" ARGS="$(WEBSITE_IMAGE)") .PHONY: website-build website-server: website-build ## Run a local server for development. @@ -266,7 +264,7 @@ website-server: website-build ## Run a local server for development. .PHONY: website-server website-push: website-build ## Push a new image and update the service. - @docker tag gvisor.dev/images/website $(WEBSITE_IMAGE) && docker push $(WEBSITE_IMAGE) + @docker push $(WEBSITE_IMAGE) .PHONY: website-push website-deploy: website-push ## Deploy a new version of the website. @@ -382,7 +380,7 @@ test-runtime: ## A convenient wrapper around test that provides the runtime argu nogo: ## Surfaces all nogo findings. @$(call submake,build OPTIONS="--build_tag_filters nogo" TARGETS="//...") - @$(call submake,run TARGETS="//tools/github" ARGS="-path=$(BUILD_ROOT) -dry-run nogo") + @$(call submake,run TARGETS="//tools/github" ARGS="$(foreach dir,$(BUILD_ROOTS),-path=$(CURDIR)/$(dir)) -dry-run nogo") .PHONY: nogo gazelle: ## Runs gazelle to update WORKSPACE. @@ -20,47 +20,28 @@ bazel_skylib_workspace() # Note that this repository actually patches some other Go repositories as it # loads it, in order to limit visibility. We hack this process by patching the # patch used by the Go rules, turning the trick against itself. + http_archive( name = "io_bazel_rules_go", + sha256 = "b725e6497741d7fc2d55fcc29a276627d10e43fa5d0bb692692890ae30d98d00", patch_args = ["-p1"], patches = [ - "//tools/nogo:io_bazel_rules_go-visibility.patch", + # Newer versions of the rules_go rules will automatically strip test + # binaries of symbols, which we don't want. + "//tools:rules_go.patch", ], - sha256 = "db2b2d35293f405430f553bc7a865a8749a8ef60c30287e90d2b278c32771afe", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.22.3/rules_go-v0.22.3.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz", + "https://github.com/bazelbuild/rules_go/releases/download/v0.24.3/rules_go-v0.24.3.tar.gz", ], ) http_archive( name = "bazel_gazelle", - sha256 = "d8c45ee70ec39a57e7a05e5027c32b1576cc7f16d9dd37135b0eddde45cf1b10", - urls = [ - "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.20.0/bazel-gazelle-v0.20.0.tar.gz", - ], -) - -http_archive( - name = "io_bazel_rules_go_bazel3", # To replace the above. - patch_args = ["-p1"], - patches = [ - "//tools/nogo:io_bazel_rules_go-visibility.patch", - ], - sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", - "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz", - ], -) - -http_archive( - name = "bazel_gazelle_bazel3", # To replace the above. - sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8", + sha256 = "b85f48fa105c4403326e9525ad2b2cc437babaa6e15a3fc0b1dbab0ab064bc7c", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", - "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz", + "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.22.2/bazel-gazelle-v0.22.2.tar.gz", ], ) @@ -117,16 +98,6 @@ rules_proto_toolchains() # See releases at https://releases.bazel.build/bazel-toolchains.html http_archive( name = "bazel_toolchains", - sha256 = "239a1a673861eabf988e9804f45da3b94da28d1aff05c373b013193c315d9d9e", - strip_prefix = "bazel-toolchains-3.0.1", - urls = [ - "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz", - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.0.1/bazel-toolchains-3.0.1.tar.gz", - ], -) - -http_archive( - name = "bazel_toolchains_bazel3", # To replace the above. sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48", strip_prefix = "bazel-toolchains-3.1.1", urls = [ @@ -208,8 +179,8 @@ http_archive( go_repository( name = "com_github_sirupsen_logrus", importpath = "github.com/sirupsen/logrus", - sum = "h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=", - version = "v1.4.2", + sum = "h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=", + version = "v1.6.0", ) go_repository( @@ -357,8 +328,8 @@ go_repository( go_repository( name = "org_golang_x_tools", importpath = "golang.org/x/tools", - sum = "h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ=", - version = "v0.0.0-20200918232735-d647fc253266", + sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", + version = "v0.0.0-20201002184944-ecd9fd270d5d", ) go_repository( @@ -441,8 +412,8 @@ go_repository( go_repository( name = "com_github_konsorten_go_windows_terminal_sequences", importpath = "github.com/konsorten/go-windows-terminal-sequences", - sum = "h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s=", - version = "v1.0.2", + sum = "h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE=", + version = "v1.0.3", ) go_repository( @@ -561,8 +532,8 @@ go_repository( go_repository( name = "com_github_containerd_continuity", importpath = "github.com/containerd/continuity", - sum = "h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k=", - version = "v0.0.0-20200710164510-efbc4488d8fe", + sum = "h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM=", + version = "v0.0.0-20200928162600-f2cc35102c2a", ) go_repository( @@ -603,8 +574,8 @@ go_repository( go_repository( name = "com_github_dustin_go_humanize", importpath = "github.com/dustin/go-humanize", - sum = "h1:qk/FSDDxo05wdJH28W+p5yivv7LuLYLRXPPD8KQCtZs=", - version = "v0.0.0-20171111073723-bb3d318650d4", + sum = "h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=", + version = "v1.0.0", ) go_repository( @@ -622,13 +593,6 @@ go_repository( ) go_repository( - name = "com_github_fsnotify_fsnotify", - importpath = "github.com/fsnotify/fsnotify", - sum = "h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=", - version = "v1.4.7", -) - -go_repository( name = "com_github_godbus_dbus", importpath = "github.com/godbus/dbus", sum = "h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8=", @@ -685,13 +649,6 @@ go_repository( ) go_repository( - name = "com_github_hpcloud_tail", - importpath = "github.com/hpcloud/tail", - sum = "h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=", - version = "v1.0.0", -) - -go_repository( name = "com_github_inconshreveable_mousetrap", importpath = "github.com/inconshreveable/mousetrap", sum = "h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=", @@ -720,20 +677,6 @@ go_repository( ) go_repository( - name = "com_github_onsi_ginkgo", - importpath = "github.com/onsi/ginkgo", - sum = "h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo=", - version = "v1.10.1", -) - -go_repository( - name = "com_github_onsi_gomega", - importpath = "github.com/onsi/gomega", - sum = "h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME=", - version = "v1.7.0", -) - -go_repository( name = "com_github_opencontainers_runc", importpath = "github.com/opencontainers/runc", sum = "h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJGY8Y=", @@ -797,34 +740,6 @@ go_repository( ) go_repository( - name = "in_gopkg_airbrake_gobrake_v2", - importpath = "gopkg.in/airbrake/gobrake.v2", - sum = "h1:7z2uVWwn7oVeeugY1DtlPAy5H+KYgB1KeKTnqjNatLo=", - version = "v2.0.9", -) - -go_repository( - name = "in_gopkg_fsnotify_v1", - importpath = "gopkg.in/fsnotify.v1", - sum = "h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4=", - version = "v1.4.7", -) - -go_repository( - name = "in_gopkg_gemnasium_logrus_airbrake_hook_v2", - importpath = "gopkg.in/gemnasium/logrus-airbrake-hook.v2", - sum = "h1:OAj3g0cR6Dx/R07QgQe8wkA9RNjB2u4i700xBkIT4e0=", - version = "v2.1.2", -) - -go_repository( - name = "in_gopkg_tomb_v1", - importpath = "gopkg.in/tomb.v1", - sum = "h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=", - version = "v1.0.0-20141024135613-dd632973f1e7", -) - -go_repository( name = "in_gopkg_yaml_v2", importpath = "gopkg.in/yaml.v2", sum = "h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=", @@ -848,15 +763,15 @@ go_repository( go_repository( name = "org_golang_google_genproto", importpath = "google.golang.org/genproto", - sum = "h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM=", - version = "v0.0.0-20200117163144-32f20d992d24", + sum = "h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY=", + version = "v0.0.0-20200526211855-cb27e3aa2013", ) go_repository( name = "org_golang_google_protobuf", importpath = "google.golang.org/protobuf", - sum = "h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=", - version = "v1.23.0", + sum = "h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w=", + version = "v1.25.1-0.20200808011614-a180de9f97d9", ) go_repository( @@ -12,7 +12,7 @@ require ( github.com/cilium/ebpf v0.0.0-20200110133405-4032b1d8aae3 // indirect github.com/containerd/cgroups v0.0.0-20181219155423-39b18af02c41 // indirect github.com/containerd/containerd v1.3.4 // indirect - github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe // indirect + github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 // indirect github.com/containerd/go-runc v0.0.0-20200220073739-7016d3ce2328 // indirect github.com/containerd/ttrpc v0.0.0-20200121165050-0be804eadb15 // indirect @@ -29,14 +29,12 @@ require ( github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e // indirect github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect github.com/gogo/googleapis v1.4.0 // indirect - github.com/golang/protobuf v1.4.2 // indirect github.com/google/go-cmp v0.5.1 // indirect github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 // indirect github.com/google/subcommands v1.0.2-0.20190508160503-636abe8753b8 // indirect github.com/hashicorp/go-multierror v1.0.0 // indirect github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 // indirect github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 // indirect - github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.1 // indirect github.com/opencontainers/runc v0.1.1 // indirect github.com/opencontainers/runtime-spec v1.0.2-0.20181111125026-1722abf79c2f // indirect @@ -48,8 +46,9 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.2.0 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect - golang.org/x/tools v0.0.0-20200918232735-d647fc253266 // indirect + golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d // indirect google.golang.org/grpc v1.29.0 // indirect + google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 // indirect gopkg.in/yaml.v2 v2.2.8 // indirect gotest.tools v2.2.0+incompatible // indirect ) @@ -48,8 +48,8 @@ github.com/containerd/containerd v1.3.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMX github.com/containerd/containerd v1.3.4 h1:3o0smo5SKY7H6AJCmJhsnCjR2/V2T8VmiHt7seN2/kI= github.com/containerd/containerd v1.3.4/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= github.com/containerd/continuity v0.0.0-20190426062206-aaeac12a7ffc/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= -github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe h1:PEmIrUvwG9Yyv+0WKZqjXfSFDeZjs/q15g0m08BYS9k= -github.com/containerd/continuity v0.0.0-20200710164510-efbc4488d8fe/go.mod h1:cECdGN1O8G9bgKTlLhuPJimka6Xb/Gg7vYzCTNVxhvo= +github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a h1:jEIoR0aA5GogXZ8pP3DUzE+zrhaF6/1rYZy+7KkYEWM= +github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a/go.mod h1:W0qIOTD7mp2He++YVq+kgfXezRYqzP1uDuMVH1bITDY= github.com/containerd/fifo v0.0.0-20190226154929-a9fb20d87448/go.mod h1:ODA38xgv3Kuk8dQz2ZQXpnv/UZZUHUCL7pnLehbXgQI= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00 h1:lsjC5ENBl+Zgf38+B0ymougXFp0BaubeIVETltYZTQw= github.com/containerd/fifo v0.0.0-20191213151349-ff969a566b00/go.mod h1:jPQ2IAeZRCYxpS/Cm1495vGFww6ecHmMk1YJH2Q5ln0= @@ -82,12 +82,11 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be h1:l+j1wSnHcimOzeeKxtspsl6tCBTyikdYxcWqFZ+Ho2c= github.com/dpjacques/clockwork v0.1.1-0.20200827220843-c1f524b839be/go.mod h1:D8mP2A8vVT2GkXqPorSBmhnshhkFBYgzhA90KmJt25Y= -github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e h1:BWhy2j3IXJhjCbC68FptL43tDKIq8FladmaTs3Xs7Z8= github.com/godbus/dbus v0.0.0-20190422162347-ade71ed3457e/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4= @@ -116,8 +115,8 @@ github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:x github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.1 h1:ZFgWrT+bLgsYPirOnRfKLYJLvssAegOj/hgyMFdJZe0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -125,6 +124,7 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-github/v28 v28.1.2-0.20191108005307-e555eab49ce8 h1:zOOUQavr8D4AZrcV4ylUpbGa5j3jfeslN6Xculz3tVU= @@ -147,7 +147,6 @@ github.com/hashicorp/go-multierror v1.0.0 h1:iVjPR7a6H0tWELX5NxNe7bYopibicUzc7uP github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= @@ -158,6 +157,8 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJKvHsGVBgn1xWYf1NbHQhywc8= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1 h1:zc0R6cOw98cMengLA0fvU55mqbnN7sd/tBMLzSejp+M= @@ -165,11 +166,7 @@ github.com/kr/pty v1.1.4-0.20190131011033-7dc38fb350b1/go.mod h1:pFQYn66WHrOpPYN github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9 h1:Sha2bQdoWE5YQPTlJOL31rmce94/tYi113SlFo1xQ2c= github.com/mohae/deepcopy v0.0.0-20170308212314-bb9b5e7adda9/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/opencontainers/go-digest v0.0.0-20180430190053-c9281466c8b2/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= -github.com/opencontainers/go-digest v1.0.0-rc1/go.mod h1:cMLVZDEM3+U2I4VmLI6N8jQYUd2OVphdqWwCJHrFt2s= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.0.1 h1:JMemWkRwHx4Zj+fVxWoMCFm/8sYGGrUVojFA6h/TRcI= @@ -184,7 +181,6 @@ github.com/opencontainers/runtime-spec v1.0.2 h1:UfAcuLBJB9Coz72x1hgl8O5RVzTdNia github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -195,10 +191,11 @@ github.com/prometheus/procfs v0.0.0-20190522114515-bc1a522cf7b1/go.mod h1:TjEm7z github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= -github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/spf13/cobra v0.0.2-0.20171109065643-2da4a54c5cee/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/pflag v1.0.1-0.20171106142849-4c012f6dcd95/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -222,7 +219,6 @@ go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.2.0 h1:6I+W7f5VwC5SV9dNrZ3qXrDB9mD0dyGOi/ZJmYw03T4= go.uber.org/multierr v1.2.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= -golang.org/x/crypto v0.0.0-20171113213409-9f005a07e0d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -281,7 +277,6 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2By golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -327,8 +322,8 @@ golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200117161641-43d50277825c/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.0.0-20200918232735-d647fc253266 h1:k7tVuG0g1JwmD3Jh8oAl1vQ1C3jb4Hi/dUl1wWDBJpQ= -golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d h1:vWQvJ/Z0Lu+9/8oQ/pAYXNzbc7CMnBl+tULGVHOy3oE= +golang.org/x/tools v0.0.0-20201002184944-ecd9fd270d5d/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -354,8 +349,9 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20200115191322-ca5a22157cba/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24 h1:wDju+RU97qa0FZT0QnZDg9Uc2dH0Ql513kFvHocz+WM= google.golang.org/genproto v0.0.0-20200117163144-32f20d992d24/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -363,6 +359,7 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.0 h1:2pJjwYOdkZ9HlN4sWRYBg9ttH5bCOlsueaM+b/oYjwo= google.golang.org/grpc v1.29.0/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= @@ -370,16 +367,13 @@ google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -gopkg.in/airbrake/gobrake.v2 v2.0.9/go.mod h1:/h5ZAUhDkGaJfjzjKLSjv6zCL6O0LLBxU4K+aSYdM/U= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9 h1:poC0iCcx0QXFYlS6nuq/8K+Ng5T55k0FXdzq52hVi4w= +google.golang.org/protobuf v1.25.1-0.20200808011614-a180de9f97d9/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= diff --git a/images/defs.bzl b/images/defs.bzl new file mode 100644 index 000000000..61d7bbf73 --- /dev/null +++ b/images/defs.bzl @@ -0,0 +1,31 @@ +"""Helpers for Docker image generation.""" + +def _docker_image_impl(ctx): + importer = ctx.actions.declare_file(ctx.label.name) + importer_content = [ + "#!/bin/bash", + "set -euo pipefail", + "exec docker import " + " ".join([ + "-c '%s'" % attr + for attr in ctx.attr.statements + ]) + " " + " ".join([ + "'%s'" % f.path + for f in ctx.files.data + ]) + " $1", + "", + ] + ctx.actions.write(importer, "\n".join(importer_content), is_executable = True) + return [DefaultInfo( + runfiles = ctx.runfiles(ctx.files.data), + executable = importer, + )] + +docker_image = rule( + implementation = _docker_image_impl, + doc = "Tool to load a Docker image; takes a single parameter (image name).", + attrs = { + "statements": attr.string_list(doc = "Extra Dockerfile directives."), + "data": attr.label_list(doc = "All image data."), + }, + executable = True, +) diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD index bee28b68d..a493e3407 100644 --- a/pkg/eventchannel/BUILD +++ b/pkg/eventchannel/BUILD @@ -6,6 +6,7 @@ go_library( name = "eventchannel", srcs = [ "event.go", + "event_any.go", "rate.go", ], visibility = ["//:sandbox"], @@ -14,8 +15,9 @@ go_library( "//pkg/log", "//pkg/sync", "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_golang_protobuf//ptypes:go_default_library_gen", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/anypb:go_default_library", "@org_golang_x_time//rate:go_default_library", ], ) @@ -32,6 +34,6 @@ go_test( library = ":eventchannel", deps = [ "//pkg/sync", - "@com_github_golang_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", ], ) diff --git a/pkg/eventchannel/event.go b/pkg/eventchannel/event.go index 9a29c58bd..7172ce75d 100644 --- a/pkg/eventchannel/event.go +++ b/pkg/eventchannel/event.go @@ -24,8 +24,8 @@ import ( "fmt" "syscall" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" pb "gvisor.dev/gvisor/pkg/eventchannel/eventchannel_go_proto" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" @@ -118,22 +118,6 @@ func (me *multiEmitter) Close() error { return err } -func marshal(msg proto.Message) ([]byte, error) { - anypb, err := ptypes.MarshalAny(msg) - if err != nil { - return nil, err - } - - // Wire format is uvarint message length followed by binary proto. - bufMsg, err := proto.Marshal(anypb) - if err != nil { - return nil, err - } - p := make([]byte, binary.MaxVarintLen64) - n := binary.PutUvarint(p, uint64(len(bufMsg))) - return append(p[:n], bufMsg...), nil -} - // socketEmitter emits proto messages on a socket. type socketEmitter struct { socket *unet.Socket @@ -155,10 +139,19 @@ func SocketEmitter(fd int) (Emitter, error) { // Emit implements Emitter.Emit. func (s *socketEmitter) Emit(msg proto.Message) (bool, error) { - p, err := marshal(msg) + any, err := newAny(msg) if err != nil { return false, err } + bufMsg, err := proto.Marshal(any) + if err != nil { + return false, err + } + + // Wire format is uvarint message length followed by binary proto. + p := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(p, uint64(len(bufMsg))) + p = append(p[:n], bufMsg...) for done := 0; done < len(p); { n, err := s.socket.Write(p[done:]) if err != nil { @@ -166,6 +159,7 @@ func (s *socketEmitter) Emit(msg proto.Message) (bool, error) { } done += n } + return false, nil } @@ -189,9 +183,13 @@ func DebugEmitterFrom(inner Emitter) Emitter { } func (d *debugEmitter) Emit(msg proto.Message) (bool, error) { + text, err := prototext.Marshal(msg) + if err != nil { + return false, err + } ev := &pb.DebugEvent{ - Name: proto.MessageName(msg), - Text: proto.MarshalTextString(msg), + Name: string(msg.ProtoReflect().Descriptor().FullName()), + Text: string(text), } return d.inner.Emit(ev) } diff --git a/pkg/eventchannel/event.proto b/pkg/eventchannel/event.proto index 34468f072..4b24ac47c 100644 --- a/pkg/eventchannel/event.proto +++ b/pkg/eventchannel/event.proto @@ -16,7 +16,7 @@ syntax = "proto3"; package gvisor; -// A debug event encapsulates any other event protobuf in text format. This is +// DebugEvent encapsulates any other event protobuf in text format. This is // useful because clients reading events emitted this way do not need to link // the event protobufs to display them in a human-readable format. message DebugEvent { diff --git a/pkg/eventchannel/event_any.go b/pkg/eventchannel/event_any.go new file mode 100644 index 000000000..a5549f6cd --- /dev/null +++ b/pkg/eventchannel/event_any.go @@ -0,0 +1,25 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package eventchannel + +import ( + "google.golang.org/protobuf/types/known/anypb" + + "google.golang.org/protobuf/proto" +) + +func newAny(m proto.Message) (*anypb.Any, error) { + return anypb.New(m) +} diff --git a/pkg/eventchannel/event_test.go b/pkg/eventchannel/event_test.go index 43750360b..0dd408f76 100644 --- a/pkg/eventchannel/event_test.go +++ b/pkg/eventchannel/event_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/sync" ) diff --git a/pkg/eventchannel/rate.go b/pkg/eventchannel/rate.go index 179226c92..74960e16a 100644 --- a/pkg/eventchannel/rate.go +++ b/pkg/eventchannel/rate.go @@ -15,8 +15,8 @@ package eventchannel import ( - "github.com/golang/protobuf/proto" "golang.org/x/time/rate" + "google.golang.org/protobuf/proto" ) // rateLimitedEmitter wraps an emitter and limits events to the given limits. diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD index 58305009d..0a6a5d215 100644 --- a/pkg/metric/BUILD +++ b/pkg/metric/BUILD @@ -27,6 +27,6 @@ go_test( deps = [ ":metric_go_proto", "//pkg/eventchannel", - "@com_github_golang_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", ], ) diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index c425ea532..aefd0ea5c 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -17,7 +17,7 @@ package metric import ( "testing" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/eventchannel" pb "gvisor.dev/gvisor/pkg/metric/metric_go_proto" ) diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 8cf5b35d3..1e11b0428 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -21,14 +21,16 @@ go_library( "directory.go", "filesystem.go", "fstree.go", - "non_directory.go", "overlay.go", + "regular_file.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/log", + "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", @@ -37,5 +39,6 @@ go_library( "//pkg/sync", "//pkg/syserror", "//pkg/usermem", + "//pkg/waiter", ], ) diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 73b126669..4506642ca 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -75,8 +75,21 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return syserror.ENOENT } - // Perform copy-up. + // Obtain settable timestamps from the lower layer. vfsObj := d.fs.vfsfs.VirtualFilesystem() + oldpop := vfs.PathOperation{ + Root: d.lowerVDs[0], + Start: d.lowerVDs[0], + } + const timestampsMask = linux.STATX_ATIME | linux.STATX_MTIME + oldStat, err := vfsObj.StatAt(ctx, d.fs.creds, &oldpop, &vfs.StatOptions{ + Mask: timestampsMask, + }) + if err != nil { + return err + } + + // Perform copy-up. newpop := vfs.PathOperation{ Root: d.parent.upperVD, Start: d.parent.upperVD, @@ -101,10 +114,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } switch ftype { case linux.S_IFREG: - oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }, &vfs.OpenOptions{ + oldFD, err := vfsObj.OpenAt(ctx, d.fs.creds, &oldpop, &vfs.OpenOptions{ Flags: linux.O_RDONLY, }) if err != nil { @@ -160,9 +170,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := newFD.SetStat(ctx, vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -179,9 +191,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -195,10 +209,7 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { d.upperVD = upperVD case linux.S_IFLNK: - target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }) + target, err := vfsObj.ReadlinkAt(ctx, d.fs.creds, &oldpop) if err != nil { return err } @@ -207,10 +218,12 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID, - Mode: uint16(d.mode), - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + Mode: uint16(d.mode), + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() @@ -224,25 +237,20 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { d.upperVD = upperVD case linux.S_IFBLK, linux.S_IFCHR: - lowerStat, err := vfsObj.StatAt(ctx, d.fs.creds, &vfs.PathOperation{ - Root: d.lowerVDs[0], - Start: d.lowerVDs[0], - }, &vfs.StatOptions{}) - if err != nil { - return err - } if err := vfsObj.MknodAt(ctx, d.fs.creds, &newpop, &vfs.MknodOptions{ Mode: linux.FileMode(d.mode), - DevMajor: lowerStat.RdevMajor, - DevMinor: lowerStat.RdevMinor, + DevMajor: oldStat.RdevMajor, + DevMinor: oldStat.RdevMinor, }); err != nil { return err } if err := vfsObj.SetStatAt(ctx, d.fs.creds, &newpop, &vfs.SetStatOptions{ Stat: linux.Statx{ - Mask: linux.STATX_UID | linux.STATX_GID, - UID: d.uid, - GID: d.gid, + Mask: linux.STATX_UID | linux.STATX_GID | oldStat.Mask×tampsMask, + UID: d.uid, + GID: d.gid, + Atime: oldStat.Atime, + Mtime: oldStat.Mtime, }, }); err != nil { cleanupUndoCopyUp() diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index bd11372d5..78a01bbb7 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -765,7 +765,7 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - if mayWrite { + if start.isRegularFile() && mayWrite { if err := start.copyUpLocked(ctx); err != nil { return nil, err } @@ -819,7 +819,7 @@ afterTrailingSymlink: if rp.MustBeDir() && !child.isDir() { return nil, syserror.ENOTDIR } - if mayWrite { + if child.isRegularFile() && mayWrite { if err := child.copyUpLocked(ctx); err != nil { return nil, err } @@ -872,8 +872,11 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * if err != nil { return nil, err } + if ftype != linux.S_IFREG { + return layerFD, nil + } layerFlags := layerFD.StatusFlags() - fd := &nonDirectoryFD{ + fd := ®ularFileFD{ copiedUp: isUpper, cachedFD: layerFD, cachedFlags: layerFlags, @@ -969,7 +972,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Finally construct the overlay FD. upperFlags := upperFD.StatusFlags() - fd := &nonDirectoryFD{ + fd := ®ularFileFD{ copiedUp: true, cachedFD: upperFD, cachedFlags: upperFlags, @@ -1293,6 +1296,9 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if !child.isDir() { return syserror.ENOTDIR } + if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(atomic.LoadUint32(&parent.mode)), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + return err + } child.dirMu.Lock() defer child.dirMu.Unlock() whiteouts, err := child.collectWhiteoutsForRmdirLocked(ctx) @@ -1528,12 +1534,38 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return err } + parentMode := atomic.LoadUint32(&parent.mode) child := parent.children[name] var childLayer lookupLayer + if child == nil { + if parentMode&linux.S_ISVTX != 0 { + // If the parent's sticky bit is set, we need a child dentry to get + // its owner. + child, err = fs.getChildLocked(ctx, parent, name, &ds) + if err != nil { + return err + } + } else { + // Determine if the file being unlinked actually exists. Holding + // parent.dirMu prevents a dentry from being instantiated for the file, + // which in turn prevents it from being copied-up, so this result is + // stable. + childLayer, err = fs.lookupLayerLocked(ctx, parent, name) + if err != nil { + return err + } + if !childLayer.existsInOverlay() { + return syserror.ENOENT + } + } + } if child != nil { if child.isDir() { return syserror.EISDIR } + if err := vfs.CheckDeleteSticky(rp.Credentials(), linux.FileMode(parentMode), auth.KUID(atomic.LoadUint32(&child.uid))); err != nil { + return err + } if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil { return err } @@ -1546,18 +1578,6 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } else { childLayer = lookupLayerLower } - } else { - // Determine if the file being unlinked actually exists. Holding - // parent.dirMu prevents a dentry from being instantiated for the file, - // which in turn prevents it from being copied-up, so this result is - // stable. - childLayer, err = fs.lookupLayerLocked(ctx, parent, name) - if err != nil { - return err - } - if !childLayer.existsInOverlay() { - return syserror.ENOENT - } } pop := vfs.PathOperation{ diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index e5f506d2e..4c5de8d32 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -18,7 +18,7 @@ // // Lock order: // -// directoryFD.mu / nonDirectoryFD.mu +// directoryFD.mu / regularFileFD.mu // filesystem.renameMu // dentry.dirMu // dentry.copyMu @@ -453,7 +453,7 @@ type dentry struct { // - If this dentry is copied-up, then wrappedMappable is the Mappable // obtained from a call to the current top layer's // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil - // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil. + // (from a call to regularFileFD.ensureMappable()), it cannot become nil. // wrappedMappable is protected by mapsMu and dataMu. // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 853aee951..2b89a7a6d 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -19,14 +19,21 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" ) +func (d *dentry) isRegularFile() bool { + return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFREG +} + func (d *dentry) isSymlink() bool { return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFLNK } @@ -40,7 +47,7 @@ func (d *dentry) readlink(ctx context.Context) (string, error) { } // +stateify savable -type nonDirectoryFD struct { +type regularFileFD struct { fileDescription // If copiedUp is false, cachedFD represents @@ -52,9 +59,13 @@ type nonDirectoryFD struct { copiedUp bool cachedFD *vfs.FileDescription cachedFlags uint32 + + // If copiedUp is false, lowerWaiters contains all waiter.Entries + // registered with cachedFD. lowerWaiters is protected by mu. + lowerWaiters map[*waiter.Entry]waiter.EventMask } -func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) { +func (fd *regularFileFD) getCurrentFD(ctx context.Context) (*vfs.FileDescription, error) { fd.mu.Lock() defer fd.mu.Unlock() wrappedFD, err := fd.currentFDLocked(ctx) @@ -65,7 +76,7 @@ func (fd *nonDirectoryFD) getCurrentFD(ctx context.Context) (*vfs.FileDescriptio return wrappedFD, nil } -func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) { +func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescription, error) { d := fd.dentry() statusFlags := fd.vfsfd.StatusFlags() if !fd.copiedUp && d.isCopiedUp() { @@ -87,10 +98,21 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip return nil, err } } + if len(fd.lowerWaiters) != 0 { + ready := upperFD.Readiness(^waiter.EventMask(0)) + for e, mask := range fd.lowerWaiters { + fd.cachedFD.EventUnregister(e) + upperFD.EventRegister(e, mask) + if ready&mask != 0 { + e.Callback.Callback(e) + } + } + } fd.cachedFD.DecRef(ctx) fd.copiedUp = true fd.cachedFD = upperFD fd.cachedFlags = statusFlags + fd.lowerWaiters = nil } else if fd.cachedFlags != statusFlags { if err := fd.cachedFD.SetStatusFlags(ctx, d.fs.creds, statusFlags); err != nil { return nil, err @@ -101,13 +123,13 @@ func (fd *nonDirectoryFD) currentFDLocked(ctx context.Context) (*vfs.FileDescrip } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *nonDirectoryFD) Release(ctx context.Context) { +func (fd *regularFileFD) Release(ctx context.Context) { fd.cachedFD.DecRef(ctx) fd.cachedFD = nil } // OnClose implements vfs.FileDescriptionImpl.OnClose. -func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { +func (fd *regularFileFD) OnClose(ctx context.Context) error { // Linux doesn't define ovl_file_operations.flush at all (i.e. its // equivalent to OnClose is a no-op). We pass through to // fd.cachedFD.OnClose() without upgrading if fd.dentry() has been @@ -128,7 +150,7 @@ func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { } // Stat implements vfs.FileDescriptionImpl.Stat. -func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { +func (fd *regularFileFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx if layerMask := opts.Mask &^ statInternalMask; layerMask != 0 { wrappedFD, err := fd.getCurrentFD(ctx) @@ -149,7 +171,7 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux } // Allocate implements vfs.FileDescriptionImpl.Allocate. -func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { +func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return err @@ -159,7 +181,7 @@ func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uin } // SetStat implements vfs.FileDescriptionImpl.SetStat. -func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { +func (fd *regularFileFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() mode := linux.FileMode(atomic.LoadUint32(&d.mode)) if err := vfs.CheckSetStat(ctx, auth.CredentialsFromContext(ctx), &opts, mode, auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))); err != nil { @@ -191,12 +213,61 @@ func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) } // StatFS implements vfs.FileDescriptionImpl.StatFS. -func (fd *nonDirectoryFD) StatFS(ctx context.Context) (linux.Statfs, error) { +func (fd *regularFileFD) StatFS(ctx context.Context) (linux.Statfs, error) { return fd.filesystem().statFS(ctx) } +// Readiness implements waiter.Waitable.Readiness. +func (fd *regularFileFD) Readiness(mask waiter.EventMask) waiter.EventMask { + ctx := context.Background() + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + // TODO(b/171089913): Just use fd.cachedFD since Readiness can't return + // an error. This is obviously wrong, but at least consistent with + // VFS1. + log.Warningf("overlay.regularFileFD.Readiness: currentFDLocked failed: %v", err) + fd.mu.Lock() + wrappedFD = fd.cachedFD + wrappedFD.IncRef() + fd.mu.Unlock() + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (fd *regularFileFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + fd.mu.Lock() + defer fd.mu.Unlock() + wrappedFD, err := fd.currentFDLocked(context.Background()) + if err != nil { + // TODO(b/171089913): Just use fd.cachedFD since EventRegister can't + // return an error. This is obviously wrong, but at least consistent + // with VFS1. + log.Warningf("overlay.regularFileFD.EventRegister: currentFDLocked failed: %v", err) + wrappedFD = fd.cachedFD + } + wrappedFD.EventRegister(e, mask) + if !fd.copiedUp { + if fd.lowerWaiters == nil { + fd.lowerWaiters = make(map[*waiter.Entry]waiter.EventMask) + } + fd.lowerWaiters[e] = mask + } +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (fd *regularFileFD) EventUnregister(e *waiter.Entry) { + fd.mu.Lock() + defer fd.mu.Unlock() + fd.cachedFD.EventUnregister(e) + if !fd.copiedUp { + delete(fd.lowerWaiters, e) + } +} + // PRead implements vfs.FileDescriptionImpl.PRead. -func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return 0, err @@ -206,7 +277,7 @@ func (fd *nonDirectoryFD) PRead(ctx context.Context, dst usermem.IOSequence, off } // Read implements vfs.FileDescriptionImpl.Read. -func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // Hold fd.mu during the read to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -218,7 +289,7 @@ func (fd *nonDirectoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts } // PWrite implements vfs.FileDescriptionImpl.PWrite. -func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { wrappedFD, err := fd.getCurrentFD(ctx) if err != nil { return 0, err @@ -228,7 +299,7 @@ func (fd *nonDirectoryFD) PWrite(ctx context.Context, src usermem.IOSequence, of } // Write implements vfs.FileDescriptionImpl.Write. -func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // Hold fd.mu during the write to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -240,7 +311,7 @@ func (fd *nonDirectoryFD) Write(ctx context.Context, src usermem.IOSequence, opt } // Seek implements vfs.FileDescriptionImpl.Seek. -func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { +func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { // Hold fd.mu during the seek to serialize the file offset. fd.mu.Lock() defer fd.mu.Unlock() @@ -252,7 +323,7 @@ func (fd *nonDirectoryFD) Seek(ctx context.Context, offset int64, whence int32) } // Sync implements vfs.FileDescriptionImpl.Sync. -func (fd *nonDirectoryFD) Sync(ctx context.Context) error { +func (fd *regularFileFD) Sync(ctx context.Context) error { fd.mu.Lock() if !fd.dentry().isCopiedUp() { fd.mu.Unlock() @@ -269,8 +340,18 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { return wrappedFD.Sync(ctx) } +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *regularFileFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return 0, err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Ioctl(ctx, uio, args) +} + // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. -func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { +func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { if err := fd.ensureMappable(ctx, opts); err != nil { return err } @@ -278,7 +359,7 @@ func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOp } // ensureMappable ensures that fd.dentry().wrappedMappable is not nil. -func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { +func (fd *regularFileFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { d := fd.dentry() // Fast path if we already have a Mappable for the current top layer. diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 1813269e0..738c0c9cc 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -147,7 +147,12 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns FSContext: kernel.NewFSContextVFS2(root, cwd, 0022), FDTable: k.NewFDTable(), } - return k.TaskSet().NewTask(config) + t, err := k.TaskSet().NewTask(ctx, config) + if err != nil { + config.ThreadGroup.Release(ctx) + return nil, err + } + return t, nil } func newFakeExecutable(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, root vfs.VirtualDentry) (*vfs.FileDescription, error) { diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 3b3c8725f..03da505e1 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -377,12 +377,12 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // enabled, we should verify the child hash here because it may // be cached before enabled. if fs.allowRuntimeEnable { - if isEnabled(parent) { + if parent.verityEnabled() { if _, err := fs.verifyChild(ctx, parent, child); err != nil { return nil, err } } - if isEnabled(child) { + if child.verityEnabled() { vfsObj := fs.vfsfs.VirtualFilesystem() mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ @@ -553,13 +553,13 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // Verify child hash. This should always be performed unless in // allowRuntimeEnable mode and the parent directory hasn't been enabled // yet. - if isEnabled(parent) { + if parent.verityEnabled() { if _, err := fs.verifyChild(ctx, parent, child); err != nil { child.destroyLocked(ctx) return nil, err } } - if isEnabled(child) { + if child.verityEnabled() { if err := fs.verifyStat(ctx, child, stat); err != nil { child.destroyLocked(ctx) return nil, err @@ -915,7 +915,7 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } - if isEnabled(d) { + if d.verityEnabled() { if err := fs.verifyStat(ctx, d, stat); err != nil { return linux.Statx{}, err } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 70034280b..8dc9e26bc 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -148,14 +148,6 @@ func (FilesystemType) Name() string { return Name } -// isEnabled checks whether the target is enabled with verity features. It -// should always be true if runtime enable is not allowed. In runtime enable -// mode, it returns true if the target has been enabled with -// ioctl(FS_IOC_ENABLE_VERITY). -func isEnabled(d *dentry) bool { - return !d.fs.allowRuntimeEnable || len(d.hash) != 0 -} - // Release implements vfs.FilesystemType.Release. func (FilesystemType) Release(ctx context.Context) {} @@ -448,6 +440,14 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +// verityEnabled checks whether the file is enabled with verity features. It +// should always be true if runtime enable is not allowed. In runtime enable +// mode, it returns true if the target has been enabled with +// ioctl(FS_IOC_ENABLE_VERITY). +func (d *dentry) verityEnabled() bool { + return !d.fs.allowRuntimeEnable || len(d.hash) != 0 +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -510,7 +510,7 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu if err != nil { return linux.Statx{}, err } - if isEnabled(fd.d) { + if fd.d.verityEnabled() { if err := fd.d.fs.verifyStat(ctx, fd.d, stat); err != nil { return linux.Statx{}, err } @@ -726,7 +726,7 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in // allowRuntimeEnable mode. - if !isEnabled(fd.d) { + if !fd.d.verityEnabled() { return fd.lowerFD.PRead(ctx, dst, offset, opts) } diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 5de70aecb..c0de72eef 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -97,6 +97,17 @@ go_template_instance( ) go_template_instance( + name = "ipc_namespace_refs", + out = "ipc_namespace_refs.go", + package = "kernel", + prefix = "IPCNamespace", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "IPCNamespace", + }, +) + +go_template_instance( name = "process_group_refs", out = "process_group_refs.go", package = "kernel", @@ -137,6 +148,7 @@ go_library( "fs_context.go", "fs_context_refs.go", "ipc_namespace.go", + "ipc_namespace_refs.go", "kcov.go", "kcov_unsafe.go", "kernel.go", @@ -206,6 +218,7 @@ go_library( "//pkg/amutex", "//pkg/bits", "//pkg/bpf", + "//pkg/cleanup", "//pkg/context", "//pkg/coverage", "//pkg/cpuid", diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index dd5f0f5fa..bb94769c4 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -81,7 +81,8 @@ func UTSNamespaceFromContext(ctx context.Context) *UTSNamespace { } // IPCNamespaceFromContext returns the IPC namespace in which ctx is executing, -// or nil if there is no such IPC namespace. +// or nil if there is no such IPC namespace. It takes a reference on the +// namespace. func IPCNamespaceFromContext(ctx context.Context) *IPCNamespace { if v := ctx.Value(CtxIPCNamespace); v != nil { return v.(*IPCNamespace) diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 80a070d7e..3f34ee0db 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -15,6 +15,7 @@ package kernel import ( + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/semaphore" "gvisor.dev/gvisor/pkg/sentry/kernel/shm" @@ -24,6 +25,8 @@ import ( // // +stateify savable type IPCNamespace struct { + IPCNamespaceRefs + // User namespace which owns this IPC namespace. Immutable. userNS *auth.UserNamespace @@ -33,11 +36,13 @@ type IPCNamespace struct { // NewIPCNamespace creates a new IPC namespace. func NewIPCNamespace(userNS *auth.UserNamespace) *IPCNamespace { - return &IPCNamespace{ + ns := &IPCNamespace{ userNS: userNS, semaphores: semaphore.NewRegistry(userNS), shms: shm.NewRegistry(userNS), } + ns.EnableLeakCheck() + return ns } // SemaphoreRegistry returns the semaphore set registry for this namespace. @@ -50,6 +55,13 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } +// DecRef implements refs_vfs2.RefCounter.DecRef. +func (i *IPCNamespace) DecRef(ctx context.Context) { + i.IPCNamespaceRefs.DecRef(func() { + i.shms.Release(ctx) + }) +} + // IPCNamespace returns the task's IPC namespace. func (t *Task) IPCNamespace() *IPCNamespace { t.mu.Lock() diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 675506269..0eb2bf7bd 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -39,6 +39,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/eventchannel" @@ -340,7 +341,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("Timekeeper is nil") } if args.Timekeeper.clocks == nil { - return fmt.Errorf("Must call Timekeeper.SetClocks() before Kernel.Init()") + return fmt.Errorf("must call Timekeeper.SetClocks() before Kernel.Init()") } if args.RootUserNamespace == nil { return fmt.Errorf("RootUserNamespace is nil") @@ -365,7 +366,7 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.useHostCores = true maxCPU, err := hostcpu.MaxPossibleCPU() if err != nil { - return fmt.Errorf("Failed to get maximum CPU number: %v", err) + return fmt.Errorf("failed to get maximum CPU number: %v", err) } minAppCores := uint(maxCPU) + 1 if k.applicationCores < minAppCores { @@ -828,7 +829,9 @@ func (ctx *createProcessContext) Value(key interface{}) interface{} { case CtxUTSNamespace: return ctx.args.UTSNamespace case CtxIPCNamespace: - return ctx.args.IPCNamespace + ipcns := ctx.args.IPCNamespace + ipcns.IncRef() + return ipcns case auth.CtxCredentials: return ctx.args.Credentials case fs.CtxRoot: @@ -964,6 +967,10 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, } tg := k.NewThreadGroup(mntns, args.PIDNamespace, NewSignalHandlers(), linux.SIGCHLD, args.Limits) + cu := cleanup.Make(func() { + tg.Release(ctx) + }) + defer cu.Clean() // Check which file to start from. switch { @@ -1023,13 +1030,14 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, MountNamespaceVFS2: mntnsVFS2, ContainerID: args.ContainerID, } - t, err := k.tasks.NewTask(config) + t, err := k.tasks.NewTask(ctx, config) if err != nil { return nil, 0, err } t.traceExecEvent(tc) // Simulate exec for tracing. // Success. + cu.Release() tgid := k.tasks.Root.IDOfThreadGroup(tg) if k.globalInit == nil { k.globalInit = tg @@ -1374,8 +1382,9 @@ func (k *Kernel) RootUTSNamespace() *UTSNamespace { return k.rootUTSNamespace } -// RootIPCNamespace returns the root IPCNamespace. +// RootIPCNamespace takes a reference and returns the root IPCNamespace. func (k *Kernel) RootIPCNamespace() *IPCNamespace { + k.rootIPCNamespace.IncRef() return k.rootIPCNamespace } @@ -1636,7 +1645,9 @@ func (ctx supervisorContext) Value(key interface{}) interface{} { case CtxUTSNamespace: return ctx.k.rootUTSNamespace case CtxIPCNamespace: - return ctx.k.rootIPCNamespace + ipcns := ctx.k.rootIPCNamespace + ipcns.IncRef() + return ipcns case auth.CtxCredentials: // The supervisor context is global root. return auth.NewRootCredentials(ctx.k.rootUserNamespace) diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index f61039f5b..1a152142b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -237,8 +237,7 @@ func (fd *VFSPipeFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal // PipeSize implements fcntl(F_GETPIPE_SZ). func (fd *VFSPipeFD) PipeSize() int64 { - // Inline Pipe.FifoSize() rather than calling it with nil Context and - // fs.File and ignoring the returned error (which is always nil). + // Inline Pipe.FifoSize() since we don't have a fs.File. fd.pipe.mu.Lock() defer fd.pipe.mu.Unlock() return fd.pipe.max diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index b7e4b480d..f8a382fd8 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -27,6 +27,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refs_vfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 00c03585e..ebbebf46b 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -321,9 +321,32 @@ func (r *Registry) remove(s *Shm) { r.totalPages -= s.effectiveSize / usermem.PageSize } +// Release drops the self-reference of each active shm segment in the registry. +// It is called when the kernel.IPCNamespace containing r is being destroyed. +func (r *Registry) Release(ctx context.Context) { + // Because Shm.DecRef() may acquire the same locks, collect the segments to + // release first. Note that this should not race with any updates to r, since + // the IPC namespace containing it has no more references. + toRelease := make([]*Shm, 0) + r.mu.Lock() + for _, s := range r.keysToShms { + s.mu.Lock() + if !s.pendingDestruction { + toRelease = append(toRelease, s) + } + s.mu.Unlock() + } + r.mu.Unlock() + + for _, s := range toRelease { + r.dissociateKey(s) + s.DecRef(ctx) + } +} + // Shm represents a single shared memory segment. // -// Shm segment are backed directly by an allocation from platform memory. +// Shm segments are backed directly by an allocation from platform memory. // Segments are always mapped as a whole, greatly simplifying how mappings are // tracked. However note that mremap and munmap calls may cause the vma for a // segment to become fragmented; which requires special care when unmapping a @@ -652,17 +675,20 @@ func (s *Shm) MarkDestroyed(ctx context.Context) { s.registry.dissociateKey(s) s.mu.Lock() - defer s.mu.Unlock() - if !s.pendingDestruction { - s.pendingDestruction = true - // Drop the self-reference so destruction occurs when all - // external references are gone. - // - // N.B. This cannot be the final DecRef, as the caller also - // holds a reference. - s.DecRef(ctx) + if s.pendingDestruction { + s.mu.Unlock() return } + s.pendingDestruction = true + s.mu.Unlock() + + // Drop the self-reference so destruction occurs when all + // external references are gone. + // + // N.B. This cannot be the final DecRef, as the caller also + // holds a reference. + s.DecRef(ctx) + return } // checkOwnership verifies whether a segment may be accessed by ctx as an diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index e90a19cfb..037971393 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -656,7 +656,9 @@ func (t *Task) Value(key interface{}) interface{} { case CtxUTSNamespace: return t.utsns case CtxIPCNamespace: - return t.ipcns + ipcns := t.IPCNamespace() + ipcns.IncRef() + return ipcns case CtxTask: return t case auth.CtxCredentials: diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index fce1064a7..682080c14 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -203,7 +204,13 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC // namespace" ipcns = NewIPCNamespace(userns) + } else { + ipcns.IncRef() } + cu := cleanup.Make(func() { + ipcns.DecRef(t) + }) + defer cu.Clean() netns := t.NetworkNamespace() if opts.NewNetworkNamespace { @@ -214,12 +221,18 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { mntnsVFS2 := t.mountNamespaceVFS2 if mntnsVFS2 != nil { mntnsVFS2.IncRef() + cu.Add(func() { + mntnsVFS2.DecRef(t) + }) } tc, err := t.tc.Fork(t, t.k, !opts.NewAddressSpace) if err != nil { return 0, nil, err } + cu.Add(func() { + tc.release() + }) // clone() returns 0 in the child. tc.Arch.SetReturn(0) if opts.Stack != 0 { @@ -295,11 +308,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } else { cfg.InheritParent = t } - nt, err := t.tg.pidns.owner.NewTask(cfg) + nt, err := t.tg.pidns.owner.NewTask(t, cfg) + // If NewTask succeeds, we transfer references to nt. If NewTask fails, it does + // the cleanup for us. + cu.Release() if err != nil { - if opts.NewThreadGroup { - tg.release(t) - } return 0, nil, err } @@ -509,6 +522,7 @@ func (t *Task) Unshare(opts *SharingOptions) error { } // Note that "If CLONE_NEWIPC is set, then create the process in a new IPC // namespace" + t.ipcns.DecRef(t) t.ipcns = NewIPCNamespace(creds.UserNamespace) } var oldFDTable *FDTable diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index b400a8b41..ce7b9641d 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -280,12 +280,13 @@ func (*runExitMain) execute(t *Task) taskRunState { t.mountNamespaceVFS2.DecRef(t) t.mountNamespaceVFS2 = nil } + t.ipcns.DecRef(t) t.mu.Unlock() // If this is the last task to exit from the thread group, release the // thread group's resources. if lastExiter { - t.tg.release(t) + t.tg.Release(t) } // Detach tracees. diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 64c1e120a..8e28230cc 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -98,14 +99,18 @@ type TaskConfig struct { // NewTask creates a new task defined by cfg. // // NewTask does not start the returned task; the caller must call Task.Start. -func (ts *TaskSet) NewTask(cfg *TaskConfig) (*Task, error) { +// +// If successful, NewTask transfers references held by cfg to the new task. +// Otherwise, NewTask releases them. +func (ts *TaskSet) NewTask(ctx context.Context, cfg *TaskConfig) (*Task, error) { t, err := ts.newTask(cfg) if err != nil { cfg.TaskContext.release() - cfg.FSContext.DecRef(t) - cfg.FDTable.DecRef(t) + cfg.FSContext.DecRef(ctx) + cfg.FDTable.DecRef(ctx) + cfg.IPCNamespace.DecRef(ctx) if cfg.MountNamespaceVFS2 != nil { - cfg.MountNamespaceVFS2.DecRef(t) + cfg.MountNamespaceVFS2.DecRef(ctx) } return nil, err } diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index 0b34c0099..a183b28c1 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -307,8 +308,8 @@ func (tg *ThreadGroup) Limits() *limits.LimitSet { return tg.limits } -// release releases the thread group's resources. -func (tg *ThreadGroup) release(t *Task) { +// Release releases the thread group's resources. +func (tg *ThreadGroup) Release(ctx context.Context) { // Timers must be destroyed without holding the TaskSet or signal mutexes // since timers send signals with Timer.mu locked. tg.itimerRealTimer.Destroy() @@ -325,7 +326,7 @@ func (tg *ThreadGroup) release(t *Task) { it.DestroyTimer() } if tg.mounts != nil { - tg.mounts.DecRef(t) + tg.mounts.DecRef(ctx) } } diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD index 7a3311a70..5b09b9feb 100644 --- a/pkg/sentry/pgalloc/BUILD +++ b/pkg/sentry/pgalloc/BUILD @@ -83,6 +83,7 @@ go_library( ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/abi/linux", "//pkg/context", "//pkg/log", "//pkg/memutil", diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 626d1eaa4..7c297fb9e 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -29,6 +29,7 @@ import ( "syscall" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" @@ -224,6 +225,18 @@ type usageInfo struct { refs uint64 } +// canCommit returns true if the tracked region can be committed. +func (u *usageInfo) canCommit() bool { + // refs must be greater than 0 because we assume that reclaimable pages + // (that aren't already known to be committed) are not committed. This + // isn't necessarily true, even after the reclaimer does Decommit(), + // because the kernel may subsequently back the hugepage-sized region + // containing the decommitted page with a hugepage. However, it's + // consistent with our treatment of unallocated pages, which have the same + // property. + return !u.knownCommitted && u.refs != 0 +} + // An EvictableMemoryUser represents a user of MemoryFile-allocated memory that // may be asked to deallocate that memory in the presence of memory pressure. type EvictableMemoryUser interface { @@ -828,6 +841,11 @@ func (f *MemoryFile) UpdateUsage() error { log.Debugf("UpdateUsage: skipped with usageSwapped!=0.") return nil } + // Linux updates usage values at CONFIG_HZ. + if scanningAfter := time.Now().Sub(f.usageLast).Milliseconds(); scanningAfter < time.Second.Milliseconds()/linux.CLOCKS_PER_SEC { + log.Debugf("UpdateUsage: skipped because previous scan happened %d ms back", scanningAfter) + return nil + } f.usageLast = time.Now() err = f.updateUsageLocked(currentUsage, mincore) @@ -841,7 +859,7 @@ func (f *MemoryFile) UpdateUsage() error { // pages by invoking checkCommitted, which is a function that, for each page i // in bs, sets committed[i] to 1 if the page is committed and 0 otherwise. // -// Precondition: f.mu must be held. +// Precondition: f.mu must be held; it may be unlocked and reacquired. func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(bs []byte, committed []byte) error) error { // Track if anything changed to elide the merge. In the common case, we // expect all segments to be committed and no merge to occur. @@ -868,7 +886,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( } else if f.usageSwapped != 0 { // We have more usage accounted for than the file itself. // That's fine, we probably caught a race where pages were - // being committed while the above loop was running. Just + // being committed while the below loop was running. Just // report the higher number that we found and ignore swap. usage.MemoryAccounting.Dec(f.usageSwapped, usage.System) f.usageSwapped = 0 @@ -880,21 +898,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( // Iterate over all usage data. There will only be usage segments // present when there is an associated reference. - for seg := f.usage.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { - val := seg.Value() - - // Already known to be committed; ignore. - if val.knownCommitted { - continue - } - - // Assume that reclaimable pages (that aren't already known to be - // committed) are not committed. This isn't necessarily true, even - // after the reclaimer does Decommit(), because the kernel may - // subsequently back the hugepage-sized region containing the - // decommitted page with a hugepage. However, it's consistent with our - // treatment of unallocated pages, which have the same property. - if val.refs == 0 { + for seg := f.usage.FirstSegment(); seg.Ok(); { + if !seg.ValuePtr().canCommit() { + seg = seg.NextSegment() continue } @@ -917,56 +923,53 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( } // Query for new pages in core. - if err := checkCommitted(s, buf); err != nil { + // NOTE(b/165896008): mincore (which is passed as checkCommitted) + // by f.UpdateUsage() might take a really long time. So unlock f.mu + // while checkCommitted runs. + f.mu.Unlock() + err := checkCommitted(s, buf) + f.mu.Lock() + if err != nil { checkErr = err return } // Scan each page and switch out segments. - populatedRun := false - populatedRunStart := 0 - for i := 0; i <= bufLen; i++ { - // We run past the end of the slice here to - // simplify the logic and only set populated if - // we're still looking at elements. - populated := false - if i < bufLen { - populated = buf[i]&0x1 != 0 - } - - switch { - case populated == populatedRun: - // Keep the run going. - continue - case populated && !populatedRun: - // Begin the run. - populatedRun = true - populatedRunStart = i - // Keep going. + seg := f.usage.LowerBoundSegment(r.Start) + for i := 0; i < bufLen; { + if buf[i]&0x1 == 0 { + i++ continue - case !populated && populatedRun: - // Finish the run by changing this segment. - runRange := memmap.FileRange{ - Start: r.Start + uint64(populatedRunStart*usermem.PageSize), - End: r.Start + uint64(i*usermem.PageSize), + } + // Scan to the end of this committed range. + j := i + 1 + for ; j < bufLen; j++ { + if buf[j]&0x1 == 0 { + break } - seg = f.usage.Isolate(seg, runRange) - seg.ValuePtr().knownCommitted = true - // Advance the segment only if we still - // have work to do in the context of - // the original segment from the for - // loop. Otherwise, the for loop itself - // will advance the segment - // appropriately. - if runRange.End != r.End { - seg = seg.NextSegment() + } + committedFR := memmap.FileRange{ + Start: r.Start + uint64(i*usermem.PageSize), + End: r.Start + uint64(j*usermem.PageSize), + } + // Advance seg to committedFR.Start. + for seg.Ok() && seg.End() < committedFR.Start { + seg = seg.NextSegment() + } + // Mark pages overlapping committedFR as committed. + for seg.Ok() && seg.Start() < committedFR.End { + if seg.ValuePtr().canCommit() { + seg = f.usage.Isolate(seg, committedFR) + seg.ValuePtr().knownCommitted = true + amount := seg.Range().Length() + usage.MemoryAccounting.Inc(amount, seg.ValuePtr().kind) + f.usageExpected += amount + changedAny = true } - amount := runRange.Length() - usage.MemoryAccounting.Inc(amount, val.kind) - f.usageExpected += amount - changedAny = true - populatedRun = false + seg = seg.NextSegment() } + // Continue scanning for committed pages. + i = j + 1 } // Advance r.Start. @@ -978,6 +981,9 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func( if err != nil { return err } + + // Continue with the first segment after r.End. + seg = f.usage.LowerBoundSegment(r.End) } return nil diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index ed5ae03d3..58f3d6fdd 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -39,6 +39,16 @@ var ( } ) +// getTLS returns the value of TPIDR_EL0 register. +// +//go:nosplit +func getTLS() (value uint64) + +// setTLS writes the TPIDR_EL0 value. +// +//go:nosplit +func setTLS(value uint64) + // bluepillArchEnter is called during bluepillEnter. // //go:nosplit @@ -51,6 +61,8 @@ func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { regs.Pstate = context.Pstate regs.Pstate &^= uint64(ring0.PsrFlagsClear) regs.Pstate |= ring0.KernelFlagsSet + regs.TPIDR_EL0 = getTLS() + return } @@ -65,6 +77,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { context.Pstate = regs.Pstate context.Pstate &^= uint64(ring0.PsrFlagsClear) context.Pstate |= ring0.UserFlagsSet + setTLS(regs.TPIDR_EL0) lazyVfp := c.GetLazyVFP() if lazyVfp != 0 { diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s index 04efa0147..09c7e88e5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.s +++ b/pkg/sentry/platform/kvm/bluepill_arm64.s @@ -32,6 +32,18 @@ #define CONTEXT_PC 0x1B8 #define CONTEXT_R0 0xB8 +// getTLS returns the value of TPIDR_EL0 register. +TEXT ·getTLS(SB),NOSPLIT,$0-8 + MRS TPIDR_EL0, R1 + MOVD R1, ret+0(FP) + RET + +// setTLS writes the TPIDR_EL0 value. +TEXT ·setTLS(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + MSR R1, TPIDR_EL0 + RET + // See bluepill.go. TEXT ·bluepill(SB),NOSPLIT,$0 begin: diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 2f1abcb0f..d91a09de1 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -53,12 +53,6 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) -// GetTLS returns the value of TPIDR_EL0 register. -func GetTLS() (value uint64) - -// SetTLS writes the TPIDR_EL0 value. -func SetTLS(value uint64) - // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 8aabf7d0e..da9d3cf55 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -29,16 +29,6 @@ TEXT ·FlushTlbAll(SB),NOSPLIT,$0 ISB $15 RET -TEXT ·GetTLS(SB),NOSPLIT,$0-8 - MRS TPIDR_EL0, R1 - MOVD R1, ret+0(FP) - RET - -TEXT ·SetTLS(SB),NOSPLIT,$0-8 - MOVD addr+0(FP), R1 - MSR R1, TPIDR_EL0 - RET - TEXT ·CPACREL1(SB),NOSPLIT,$0-8 WORD $0xd5381041 // MRS CPACR_EL1, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 5ddcd4be5..3baad098b 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -16,6 +16,7 @@ package netlink import ( + "io" "math" "gvisor.dev/gvisor/pkg/abi/linux" @@ -748,6 +749,12 @@ func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, t buf := make([]byte, src.NumBytes()) n, err := src.CopyIn(ctx, buf) + // io.EOF can be only returned if src is a file, this means that + // sendMsg is called from splice and the error has to be ignored in + // this case. + if err == io.EOF { + err = nil + } if err != nil { // Don't partially consume messages. return 0, syserr.FromError(err) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 87e30d742..211f07947 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -587,6 +587,11 @@ func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { + // EOF can be returned only if src is a file and this means it + // is in a splice syscall and the error has to be ignored. + if err == io.EOF { + return v, nil + } return nil, tcpip.ErrBadAddress } return v, nil diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index f80011ce4..a4a76d0a3 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -573,13 +573,17 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS if dst.NumBytes() == 0 { return 0, nil } - return dst.CopyOutFrom(ctx, &EndpointReader{ + r := &EndpointReader{ Ctx: ctx, Endpoint: s.ep, NumRights: 0, Peek: false, From: nil, - }) + } + n, err := dst.CopyOutFrom(ctx, r) + // Drop control messages. + r.Control.Release(ctx) + return n, err } // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 3345124cc..678355fb9 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -267,13 +267,17 @@ func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs. if dst.NumBytes() == 0 { return 0, nil } - return dst.CopyOutFrom(ctx, &EndpointReader{ + r := &EndpointReader{ Ctx: ctx, Endpoint: s.ep, NumRights: 0, Peek: false, From: nil, - }) + } + n, err := dst.CopyOutFrom(ctx, r) + // Drop control messages. + r.Control.Release(ctx) + return n, err } // PWrite implements vfs.FileDescriptionImpl. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9feaca0da..9cd052c3d 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -1052,7 +1052,9 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) - if err != nil { + // Control messages should be released on error as well as for zero-length + // messages, which are discarded by the receiver. + if n == 0 || err != nil { controlMessages.Release(t) } return uintptr(n), err diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 6320593f0..db3d924d9 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -21,7 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/usage" ) -// Sysinfo implements the sysinfo syscall as described in man 2 sysinfo. +// Sysinfo implements Linux syscall sysinfo(2). func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index d8b8d9783..36e89700e 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -145,16 +145,6 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(file.StatusFlags()), nil, nil case linux.F_SETFL: return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint()) - case linux.F_SETPIPE_SZ: - pipefile, ok := file.Impl().(*pipe.VFSPipeFD) - if !ok { - return 0, nil, syserror.EBADF - } - n, err := pipefile.SetPipeSize(int64(args[2].Int())) - if err != nil { - return 0, nil, err - } - return uintptr(n), nil, nil case linux.F_GETOWN: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -190,6 +180,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + case linux.F_SETPIPE_SZ: + pipefile, ok := file.Impl().(*pipe.VFSPipeFD) + if !ok { + return 0, nil, syserror.EBADF + } + n, err := pipefile.SetPipeSize(int64(args[2].Int())) + if err != nil { + return 0, nil, err + } + return uintptr(n), nil, nil case linux.F_GETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index bfae6b7e9..7b33b3f59 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -1055,7 +1055,9 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) - if err != nil { + // Control messages should be released on error as well as for zero-length + // messages, which are discarded by the receiver. + if n == 0 || err != nil { controlMessages.Release(t) } return uintptr(n), err diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index bf5c1171f..035e2a6b0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -45,6 +45,9 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if count > int64(kernel.MAX_RW_COUNT) { count = int64(kernel.MAX_RW_COUNT) } + if count < 0 { + return 0, nil, syserror.EINVAL + } // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { @@ -192,6 +195,9 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo if count > int64(kernel.MAX_RW_COUNT) { count = int64(kernel.MAX_RW_COUNT) } + if count < 0 { + return 0, nil, syserror.EINVAL + } // Check for invalid flags. if flags&^(linux.SPLICE_F_MOVE|linux.SPLICE_F_NONBLOCK|linux.SPLICE_F_MORE|linux.SPLICE_F_GIFT) != 0 { diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go index ab1d140d2..5ed6726ab 100644 --- a/pkg/sentry/usage/memory.go +++ b/pkg/sentry/usage/memory.go @@ -278,7 +278,7 @@ func TotalMemory(memSize, used uint64) uint64 { } if memSize < used { memSize = used - // Bump totalSize to the next largest power of 2, if one exists, so + // Bump memSize to the next largest power of 2, if one exists, so // that MemFree isn't 0. if msb := bits.MostSignificantOne64(memSize); msb < 63 { memSize = uint64(1) << (uint(msb) + 1) diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD index ba2ed1ea7..abb8c3be3 100644 --- a/pkg/shim/v2/runtimeoptions/BUILD +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -11,12 +11,12 @@ proto_library( go_library( name = "runtimeoptions", - srcs = ["runtimeoptions.go"], - visibility = ["//pkg/shim/v2:__pkg__"], - deps = [ - ":api_go_proto", - "@com_github_gogo_protobuf//proto:go_default_library", + srcs = [ + "runtimeoptions.go", + "runtimeoptions_cri.go", ], + visibility = ["//pkg/shim/v2:__pkg__"], + deps = ["@com_github_gogo_protobuf//proto:go_default_library"], ) go_test( @@ -27,6 +27,6 @@ go_test( deps = [ "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", "@com_github_containerd_typeurl//:go_default_library", - "@com_github_golang_protobuf//proto:go_default_library", + "@com_github_gogo_protobuf//proto:go_default_library", ], ) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go index aaf17b87a..072dd87f0 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions.go +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -13,18 +13,5 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package runtimeoptions contains the runtimeoptions proto. package runtimeoptions - -import ( - proto "github.com/gogo/protobuf/proto" - pb "gvisor.dev/gvisor/pkg/shim/v2/runtimeoptions/api_go_proto" -) - -type Options = pb.Options - -func init() { - // The generated proto file auto registers with "golang/protobuf/proto" - // package. However, typeurl uses "golang/gogo/protobuf/proto". So registers - // the type there too. - proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") -} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go new file mode 100644 index 000000000..e6102b4cf --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go @@ -0,0 +1,383 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + "fmt" + "io" + "reflect" + "strings" + + proto "github.com/gogo/protobuf/proto" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package + +type Options struct { + // TypeUrl specifies the type of the content inside the config file. + TypeUrl string `protobuf:"bytes,1,opt,name=type_url,json=typeUrl,proto3" json:"type_url,omitempty"` + // ConfigPath specifies the filesystem location of the config file + // used by the runtime. + ConfigPath string `protobuf:"bytes,2,opt,name=config_path,json=configPath,proto3" json:"config_path,omitempty"` +} + +func (m *Options) Reset() { *m = Options{} } +func (*Options) ProtoMessage() {} +func (*Options) Descriptor() ([]byte, []int) { return fileDescriptorApi, []int{0} } + +func (m *Options) GetTypeUrl() string { + if m != nil { + return m.TypeUrl + } + return "" +} + +func (m *Options) GetConfigPath() string { + if m != nil { + return m.ConfigPath + } + return "" +} + +func init() { + proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") +} + +func (m *Options) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalTo(dAtA) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Options) MarshalTo(dAtA []byte) (int, error) { + var i int + _ = i + var l int + _ = l + if len(m.TypeUrl) > 0 { + dAtA[i] = 0xa + i++ + i = encodeVarintApi(dAtA, i, uint64(len(m.TypeUrl))) + i += copy(dAtA[i:], m.TypeUrl) + } + if len(m.ConfigPath) > 0 { + dAtA[i] = 0x12 + i++ + i = encodeVarintApi(dAtA, i, uint64(len(m.ConfigPath))) + i += copy(dAtA[i:], m.ConfigPath) + } + return i, nil +} + +func encodeVarintApi(dAtA []byte, offset int, v uint64) int { + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return offset + 1 +} + +func (m *Options) Size() (n int) { + var l int + _ = l + l = len(m.TypeUrl) + if l > 0 { + n += 1 + l + sovApi(uint64(l)) + } + l = len(m.ConfigPath) + if l > 0 { + n += 1 + l + sovApi(uint64(l)) + } + return n +} + +func sovApi(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} + +func sozApi(x uint64) (n int) { + return sovApi(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +func (this *Options) String() string { + if this == nil { + return "nil" + } + s := strings.Join([]string{`&Options{`, + `TypeUrl:` + fmt.Sprintf("%v", this.TypeUrl) + `,`, + `ConfigPath:` + fmt.Sprintf("%v", this.ConfigPath) + `,`, + `}`, + }, "") + return s +} + +func valueToStringApi(v interface{}) string { + rv := reflect.ValueOf(v) + if rv.IsNil() { + return "nil" + } + pv := reflect.Indirect(rv).Interface() + return fmt.Sprintf("*%v", pv) +} + +func (m *Options) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Options: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Options: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TypeUrl", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.TypeUrl = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field ConfigPath", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowApi + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthApi + } + postIndex := iNdEx + intStringLen + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.ConfigPath = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipApi(dAtA[iNdEx:]) + if err != nil { + return err + } + if skippy < 0 { + return ErrInvalidLengthApi + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skipApi(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + return iNdEx, nil + case 1: + iNdEx += 8 + return iNdEx, nil + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + iNdEx += length + if length < 0 { + return 0, ErrInvalidLengthApi + } + return iNdEx, nil + case 3: + for { + var innerWire uint64 + var start int = iNdEx + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowApi + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + innerWire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + innerWireType := int(innerWire & 0x7) + if innerWireType == 4 { + break + } + next, err := skipApi(dAtA[start:]) + if err != nil { + return 0, err + } + iNdEx = start + next + } + return iNdEx, nil + case 4: + return iNdEx, nil + case 5: + iNdEx += 4 + return iNdEx, nil + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + } + panic("unreachable") +} + +var ( + ErrInvalidLengthApi = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowApi = fmt.Errorf("proto: integer overflow") +) + +func init() { proto.RegisterFile("api.proto", fileDescriptorApi) } + +var fileDescriptorApi = []byte{ + // 183 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4c, 0x2c, 0xc8, 0xd4, + 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0x4d, 0x2e, 0xca, 0xd4, 0x2b, 0x2a, 0xcd, 0x2b, 0xc9, + 0xcc, 0x4d, 0xcd, 0x2f, 0x28, 0xc9, 0xcc, 0xcf, 0x2b, 0xd6, 0x2b, 0x33, 0x94, 0xd2, 0x4d, 0xcf, + 0x2c, 0xc9, 0x28, 0x4d, 0xd2, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xcf, 0x4f, 0xcf, 0xd7, 0x07, 0xab, + 0x4e, 0x2a, 0x4d, 0x03, 0xf3, 0xc0, 0x1c, 0x30, 0x0b, 0x62, 0x8a, 0x92, 0x2b, 0x17, 0xbb, 0x3f, + 0x44, 0xb3, 0x90, 0x24, 0x17, 0x47, 0x49, 0x65, 0x41, 0x6a, 0x7c, 0x69, 0x51, 0x8e, 0x04, 0xa3, + 0x02, 0xa3, 0x06, 0x67, 0x10, 0x3b, 0x88, 0x1f, 0x5a, 0x94, 0x23, 0x24, 0xcf, 0xc5, 0x9d, 0x9c, + 0x9f, 0x97, 0x96, 0x99, 0x1e, 0x5f, 0x90, 0x58, 0x92, 0x21, 0xc1, 0x04, 0x96, 0xe5, 0x82, 0x08, + 0x05, 0x24, 0x96, 0x64, 0x38, 0xc9, 0x9c, 0x78, 0x28, 0xc7, 0x78, 0xe3, 0xa1, 0x1c, 0x43, 0xc3, + 0x23, 0x39, 0xc6, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, + 0xc2, 0x63, 0x39, 0x86, 0x24, 0x36, 0xb0, 0x5d, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, 0xff, 0x07, + 0x00, 0xf2, 0x18, 0xbe, 0x00, 0x00, 0x00, +} diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go index f4c238a00..c59a2400e 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go @@ -15,11 +15,12 @@ package runtimeoptions import ( + "bytes" "testing" shim "github.com/containerd/containerd/runtime/v1/shim/v1" "github.com/containerd/typeurl" - "github.com/golang/protobuf/proto" + "github.com/gogo/protobuf/proto" ) func TestCreateTaskRequest(t *testing.T) { @@ -32,7 +33,11 @@ func TestCreateTaskRequest(t *testing.T) { if err := proto.UnmarshalText(encodedText, got); err != nil { t.Fatalf("unable to unmarshal text: %v", err) } - t.Logf("got: %s", proto.MarshalTextString(got)) + var textBuffer bytes.Buffer + if err := proto.MarshalText(&textBuffer, got); err != nil { + t.Errorf("unable to marshal text: %v", err) + } + t.Logf("got: %s", string(textBuffer.Bytes())) // Check the options. wantOptions := &Options{} diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index d4d785cca..6f81b0164 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker { } } +// IPPayload creates a checker that checks the payload. +func IPPayload(payload []byte) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + got := h[0].Payload() + + // cmp.Diff does not consider nil slices equal to empty slices, but we do. + if len(got) == 0 && len(payload) == 0 { + return + } + + if diff := cmp.Diff(payload, got); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + } +} + // IPv4Options returns a checker that checks the options in an IPv4 packet. func IPv4Options(want []byte) NetworkChecker { return func(t *testing.T, h []header.Network) { diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD new file mode 100644 index 000000000..ec92ed623 --- /dev/null +++ b/pkg/tcpip/link/ethernet/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "ethernet", + srcs = ["ethernet.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go new file mode 100644 index 000000000..3eef7cd56 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ethernet provides an implementation of an ethernet link endpoint that +// wraps an inner link endpoint. +package ethernet + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkDispatcher = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns an ethernet link endpoint that wraps an inner link endpoint. +func New(ep stack.LinkEndpoint) *Endpoint { + var e Endpoint + e.Endpoint.Init(ep, &e) + return &e +} + +// Endpoint is an ethernet endpoint. +// +// It adds an ethernet header to packets before sending them out through its +// inner link endpoint and consumes an ethernet header before sending the +// packet to the stack. +type Endpoint struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + return + } + + eth := header.Ethernet(hdr) + if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt) + } +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + return e.Endpoint.WritePacket(r, gso, proto, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + linkAddr := e.Endpoint.LinkAddress() + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength() +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + fields := header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: proto, + } + eth.Encode(&fields) +} diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go index 76f563811..523b0d24b 100644 --- a/pkg/tcpip/link/pipe/pipe.go +++ b/pkg/tcpip/link/pipe/pipe.go @@ -26,27 +26,23 @@ import ( var _ stack.LinkEndpoint = (*Endpoint)(nil) // New returns both ends of a new pipe. -func New(linkAddr1, linkAddr2 tcpip.LinkAddress, capabilities stack.LinkEndpointCapabilities) (*Endpoint, *Endpoint) { +func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) { ep1 := &Endpoint{ - linkAddr: linkAddr1, - capabilities: capabilities, + linkAddr: linkAddr1, } ep2 := &Endpoint{ - linkAddr: linkAddr2, - linked: ep1, - capabilities: capabilities, + linkAddr: linkAddr2, } ep1.linked = ep2 + ep2.linked = ep1 return ep1, ep2 } // Endpoint is one end of a pipe. type Endpoint struct { - capabilities stack.LinkEndpointCapabilities - linkAddr tcpip.LinkAddress - dispatcher stack.NetworkDispatcher - linked *Endpoint - onWritePacket func(*stack.PacketBuffer) + dispatcher stack.NetworkDispatcher + linked *Endpoint + linkAddr tcpip.LinkAddress } // WritePacket implements stack.LinkEndpoint. @@ -55,16 +51,11 @@ func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.Network return nil } - // The pipe endpoint will accept all multicast/broadcast link traffic and only - // unicast traffic destined to itself. - if len(e.linked.linkAddr) != 0 && - r.RemoteLinkAddress != e.linked.linkAddr && - r.RemoteLinkAddress != header.EthernetBroadcastAddress && - !header.IsMulticastEthernetAddress(r.RemoteLinkAddress) { - return nil - } - - e.linked.dispatcher.DeliverNetworkPacket(e.linkAddr, r.RemoteLinkAddress, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + // Note that the local address from the perspective of this endpoint is the + // remote address from the perspective of the other end of the pipe + // (e.linked). Similarly, the remote address from the perspective of this + // endpoint is the local address on the other end. + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), })) @@ -100,8 +91,8 @@ func (*Endpoint) MTU() uint32 { } // Capabilities implements stack.LinkEndpoint. -func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { - return e.capabilities +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 } // MaxHeaderLength implements stack.LinkEndpoint. @@ -116,7 +107,7 @@ func (e *Endpoint) LinkAddress() tcpip.LinkAddress { // ARPHardwareType implements stack.LinkEndpoint. func (*Endpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareEther + return header.ARPHardwareNone } // AddHeader implements stack.LinkEndpoint. diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 59710352b..c118a2929 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -12,6 +12,7 @@ go_test( "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index d436873b6..f20b94d97 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,11 +15,13 @@ package ip_test import ( + "strings" "testing" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -320,6 +322,7 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -342,7 +345,6 @@ func TestSourceAddressValidation(t *testing.T) { SrcAddr: src, DstAddr: localIPv6Addr, }) - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -579,6 +581,7 @@ func TestIPv4Receive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { @@ -660,6 +663,7 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: "\x0a\x00\x00\xbb", DstAddr: localIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Create the ICMP header. icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) @@ -679,6 +683,7 @@ func TestIPv4ReceiveControl(t *testing.T) { SrcAddr: localIPv4Addr, DstAddr: remoteIPv4Addr, }) + ip.SetChecksum(^ip.CalculateChecksum()) // Make payload be non-zero. for i := dataOffset; i < len(view); i++ { @@ -732,6 +737,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip1.SetChecksum(^ip1.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag1[i] = uint8(i) @@ -748,6 +755,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: localIPv4Addr, }) + ip2.SetChecksum(^ip2.CalculateChecksum()) + // Make payload be non-zero. for i := header.IPv4MinimumSize; i < totalLen; i++ { frag2[i] = uint8(i) @@ -1020,3 +1029,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer _, _ = pkt.NetworkHeader().Consume(netHdrLen) return pkt } + +func TestWriteHeaderIncludedPacket(t *testing.T) { + const ( + nicID = 1 + transportProto = 5 + + dataLen = 4 + optionsLen = 4 + ) + + dataBuf := [dataLen]byte{1, 2, 3, 4} + data := dataBuf[:] + + ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1} + ipv4Options := ipv4OptionsBuf[:] + + ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4} + ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:] + + var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte + ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:] + if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + + tests := []struct { + name string + protoFactory stack.NetworkProtocolFactory + protoNum tcpip.NetworkProtocolNumber + nicAddr tcpip.Address + remoteAddr tcpip.Address + pktGen func(*testing.T, tcpip.Address) buffer.View + checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) + expectedErr *tcpip.Error + }{ + { + name: "IPv4", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with IHL too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv4MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 1, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 too small", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + { + name: "IPv4 minimum size", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv4MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.IPFullLength(header.IPv4MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv4 with options", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) + totalLen := ipHdrLen + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv4(hdr.Prepend(ipHdrLen)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(ipHdrLen), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) + } + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv6 with extension header", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) + hdr := buffer.NewPrependable(totalLen) + if n := copy(hdr.Prepend(len(data)), data); n != len(data) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(data)) + } + if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) { + t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr)) + } + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier), + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return hdr.View() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))), + checker.IPPayload(ipv6PayloadWithExtHdr), + ) + }, + }, + { + name: "IPv6 minimum size", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip) + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv6Any { + src = localIPv6Addr + } + + netHdr := pkt.NetworkHeader() + + if len(netHdr.View()) != header.IPv6MinimumSize { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize) + } + + checker.IPv6(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv6Addr), + checker.IPFullLength(header.IPv6MinimumSize), + checker.IPPayload(nil), + ) + }, + }, + { + name: "IPv6 too small", + protoFactory: ipv6.NewProtocol, + protoNum: ipv6.ProtocolNumber, + nicAddr: localIPv6Addr, + remoteAddr: remoteIPv6Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + NextHeader: transportProto, + HopLimit: ipv6.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + return buffer.View(ip[:len(ip)-1]) + }, + expectedErr: tcpip.ErrMalformedHeader, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + subTests := []struct { + name string + srcAddr tcpip.Address + }{ + { + name: "unspecified source", + srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))), + }, + { + name: "random source", + srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))), + }, + } + + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory}, + }) + e := channel.New(1, 1280, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}}) + + r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err) + } + defer r.Release() + + if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + })); err != test.expectedErr { + t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := e.Read() + if !ok { + t.Fatal("expected a packet to be written") + } + test.checker(t, pkt.Pkt, subTest.srcAddr) + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index c5ac7b8b5..e7c58ae0a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -190,29 +190,6 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { return e.protocol.Number() } -// writePacketFragments fragments pkt and writes the results on the link -// endpoint. The IP header must already present in the original packet. The mtu -// is the maximum size of the packets. -func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer) *tcpip.Error { - networkHeader := header.IPv4(pkt.NetworkHeader().View()) - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) - - for { - fragPkt, more := buildNextFragment(&pf, networkHeader) - if err := e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pf.RemainingFragmentCount() + 1)) - return err - } - r.Stats().IP.PacketsSent.Increment() - if !more { - break - } - } - - return nil -} - func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params stack.NetworkHeaderParams) { ip := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize)) length := uint16(pkt.Size()) @@ -234,10 +211,39 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } +func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { + return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +} + +// handleFragments fragments pkt and calls the handler function on each +// fragment. It returns the number of fragments handled and the number of +// fragments left to be processed. The IP header must already be present in the +// original packet. The mtu is the maximum size of the packets. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) + networkHeader := header.IPv4(pkt.NetworkHeader().View()) + pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + + var n int + for { + fragPkt, more := buildNextFragment(&pf, networkHeader) + if err := handler(fragPkt); err != nil { + return n, pf.RemainingFragmentCount() + 1, err + } + n++ + if !more { + return n, pf.RemainingFragmentCount(), nil + } + } +} + // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -273,8 +279,18 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw if r.Loop&stack.PacketOut == 0 { return nil } - if pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) { - return e.writePacketFragments(r, gso, e.nic.MTU(), pkt) + + if e.packetMustBeFragmented(pkt, gso) { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each + // fragment one by one using WritePacket() (current strategy) or if we + // want to create a PacketBufferList from the fragments and feed it to + // WritePackets(). It'll be faster but cost more memory. + return e.nic.WritePacket(r, gso, ProtocolNumber, fragPkt) + }) + r.Stats().IP.PacketsSent.IncrementBy(uint64(sent)) + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) + return err } if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -293,9 +309,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } - for pkt := pkts.Front(); pkt != nil; { + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - pkt = pkt.Next() + if e.packetMustBeFragmented(pkt, gso) { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pkt + if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Modify the packet list in place with the new fragments. + pkts.InsertAfter(pkt, fragPkt) + pkt = fragPkt + return nil + }); err != nil { + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + } + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) + } } nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -347,30 +377,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacket writes a packet already containing a network -// header through the given route. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required // checks. h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrMalformedHeader } ip := header.IPv4(h) - if !ip.IsValid(pkt.Data.Size()) { - return tcpip.ErrInvalidOptionValue - } // Always set the total length. - ip.SetTotalLength(uint16(pkt.Data.Size())) + pktSize := pkt.Data.Size() + ip.SetTotalLength(uint16(pktSize)) // Set the source address when zero. - if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + if ip.SourceAddress() == header.IPv4Any { ip.SetSourceAddress(r.LocalAddress) } - // Set the destination. If the packet already included a destination, - // it will be part of the route. + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. ip.SetDestinationAddress(r.RemoteAddress) // Set the packet ID when zero. @@ -387,19 +414,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu ip.SetChecksum(0) ip.SetChecksum(^ip.CalculateChecksum()) - if r.Loop&stack.PacketLoop != 0 { - e.HandlePacket(r, pkt.Clone()) - } - if r.Loop&stack.PacketOut == 0 { - return nil + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader } - if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil { - r.Stats().IP.OutgoingPacketErrors.Increment() - return err - } - r.Stats().IP.PacketsSent.Increment() - return nil + return e.writePacket(r, nil /* gso */, pkt) } // HandlePacket is called by the link layer when new ipv4 packets arrive for @@ -415,6 +440,32 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } + // There has been some confusion regarding verifying checksums. We need + // just look for negative 0 (0xffff) as the checksum, as it's not possible to + // get positive 0 (0) for the checksum. Some bad implementations could get it + // when doing entry replacement in the early days of the Internet, + // however the lore that one needs to check for both persists. + // + // RFC 1624 section 1 describes the source of this confusion as: + // [the partial recalculation method described in RFC 1071] computes a + // result for certain cases that differs from the one obtained from + // scratch (one's complement of one's complement sum of the original + // fields). + // + // However RFC 1624 section 5 clarifies that if using the verification method + // "recommended by RFC 1071, it does not matter if an intermediate system + // generated a -0 instead of +0". + // + // RFC1071 page 1 specifies the verification method as: + // (3) To check a checksum, the 1's complement sum is computed over the + // same set of octets, including the checksum field. If the result + // is all 1 bits (-0 in 1's complement arithmetic), the check + // succeeds. + if h.CalculateChecksum() != 0xffff { + r.Stats().IP.MalformedPacketsReceived.Increment() + return + } + // As per RFC 1122 section 3.2.1.3: // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9916d783f..fee11bb38 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,9 @@ package ipv4_test import ( - "bytes" "context" "encoding/hex" + "fmt" "math" "net" "testing" @@ -39,6 +39,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +const extraHeaderReserve = 50 + func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, @@ -118,6 +120,7 @@ func TestIPv4Sanity(t *testing.T) { tests := []struct { name string headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool maxTotalLength uint16 transportProtocol uint8 TTL uint8 @@ -133,6 +136,14 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, + { + name: "bad header checksum", + maxTotalLength: defaultMTU, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + badHeaderChecksum: true, + shouldFail: true, + }, // The TTL tests check that we are not rejecting an incoming packet // with a zero or one TTL, which has been a point of confusion in the // past as RFC 791 says: "If this field contains the value zero, then the @@ -243,7 +254,7 @@ func TestIPv4Sanity(t *testing.T) { // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: header.IPv4EmptySubnet, NIC: nicID, }, @@ -288,6 +299,12 @@ func TestIPv4Sanity(t *testing.T) { if test.headerLength != 0 { ip.SetHeaderLength(test.headerLength) } + ip.SetChecksum(0) + ipHeaderChecksum := ip.CalculateChecksum() + if test.badHeaderChecksum { + ipHeaderChecksum += 42 + } + ip.SetChecksum(^ipHeaderChecksum) requestPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), }) @@ -369,11 +386,10 @@ func TestIPv4Sanity(t *testing.T) { // comparePayloads compared the contents of all the packets against the contents // of the source packet. -func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { - t.Helper() - // Make a complete array of the sourcePacketInfo packet. - source := header.IPv4(packets[0].NetworkHeader().View()[:header.IPv4MinimumSize]) - vv := buffer.NewVectorisedView(sourcePacketInfo.Size(), sourcePacketInfo.Views()) +func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketBuffer, mtu uint32, wantFragments []fragmentInfo, proto tcpip.TransportProtocolNumber) error { + // Make a complete array of the sourcePacket packet. + source := header.IPv4(packets[0].NetworkHeader().View()) + vv := buffer.NewVectorisedView(sourcePacket.Size(), sourcePacket.Views()) source = append(source, vv.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make @@ -382,82 +398,147 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) - var offset uint16 // Build up an array of the bytes sent. - var reassembledPayload []byte + var reassembledPayload buffer.VectorisedView for i, packet := range packets { // Confirm that the packet is valid. allBytes := buffer.NewVectorisedView(packet.Size(), packet.Views()) - ip := header.IPv4(allBytes.ToView()) - if !ip.IsValid(len(ip)) { - t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) + fragmentIPHeader := header.IPv4(allBytes.ToView()) + if !fragmentIPHeader.IsValid(len(fragmentIPHeader)) { + return fmt.Errorf("fragment #%d: IP packet is invalid:\n%s", i, hex.Dump(fragmentIPHeader)) } - if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { - t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) + if got := len(fragmentIPHeader); got > int(mtu) { + return fmt.Errorf("fragment #%d: got len(fragmentIPHeader) = %d, want <= %d", i, got, mtu) } - if got, want := len(ip), int(mtu); got > want { - t.Errorf("fragment is too large, got %d want %d", got, want) + if got := fragmentIPHeader.TransportProtocol(); got != proto { + return fmt.Errorf("fragment #%d: got fragmentIPHeader.TransportProtocol() = %d, want = %d", i, got, uint8(proto)) } - if got, want := packet.AvailableHeaderBytes(), sourcePacketInfo.AvailableHeaderBytes()-header.IPv4MinimumSize; got != want { - t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) + if got := packet.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) } - if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { - t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + if got, want := packet.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber; got != want { + return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, got, want) } - if i < len(packets)-1 { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) + if got, want := fragmentIPHeader.CalculateChecksum(), uint16(0xffff); got != want { + return fmt.Errorf("fragment #%d: got ip.CalculateChecksum() = %#x, want = %#x", i, got, want) + } + if wantFragments[i].more { + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, wantFragments[i].offset) } else { - sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) + sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, wantFragments[i].offset) } - reassembledPayload = append(reassembledPayload, ip.Payload()...) - offset += ip.TotalLength() - uint16(ip.HeaderLength()) + reassembledPayload.AppendView(packet.TransportHeader().View()) + reassembledPayload.Append(packet.Data) // Clear out the checksum and length from the ip because we can't compare // it. - sourceCopy.SetTotalLength(uint16(len(ip))) + sourceCopy.SetTotalLength(wantFragments[i].payloadSize + header.IPv4MinimumSize) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) - if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { - t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) + if diff := cmp.Diff(fragmentIPHeader[:fragmentIPHeader.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]); diff != "" { + return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } } - expected := source[source.HeaderLength():] - if !bytes.Equal(reassembledPayload, expected) { - t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) + + expected := buffer.View(source[source.HeaderLength():]) + if diff := cmp.Diff(expected, reassembledPayload.ToView()); diff != "" { + return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } + + return nil } -func TestFragmentation(t *testing.T) { - const ttl = 42 +type fragmentInfo struct { + offset uint16 + more bool + payloadSize uint16 +} - var manyPayloadViewsSizes [1000]int - for i := range manyPayloadViewsSizes { - manyPayloadViewsSizes[i] = 7 - } - fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - transportHeaderLength int - extraHeaderReserveLength int - payloadViewsSizes []int - expectedFrags int - }{ - {"No fragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, - {"No fragmentation with big header", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, - {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with gso nil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with many views", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, - {"Fragmented with many views and prependable bytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, - {"Fragmented with big header", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, - {"Fragmented with big header and prependable bytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, - {"Fragmented with MTU smaller than header and prependable bytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, - } +var fragmentationTests = []struct { + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + payloadSize int + wantFragments []fragmentInfo +}{ + { + description: "No Fragmentation", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1000, more: false}, + }, + }, + { + description: "Fragmented", + mtu: 1280, + gso: nil, + transportHeaderLength: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 744, more: false}, + }, + }, + { + description: "No fragmentation with big header", + mtu: 2000, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1100, more: false}, + }, + }, + { + description: "Fragmented with gso none", + mtu: 1280, + gso: &stack.GSO{Type: stack.GSONone}, + transportHeaderLength: 0, + payloadSize: 1400, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 144, more: false}, + }, + }, + { + description: "Fragmented with big header", + mtu: 1280, + gso: nil, + transportHeaderLength: 100, + payloadSize: 1200, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1256, more: true}, + {offset: 1256, payloadSize: 44, more: false}, + }, + }, + { + description: "Fragmented with MTU smaller than header", + mtu: 300, + gso: nil, + transportHeaderLength: 1000, + payloadSize: 500, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 280, more: true}, + {offset: 280, payloadSize: 280, more: true}, + {offset: 560, payloadSize: 280, more: true}, + {offset: 840, payloadSize: 280, more: true}, + {offset: 1120, payloadSize: 280, more: true}, + {offset: 1400, payloadSize: 100, more: false}, + }, + }, +} - for _, ft := range fragTests { +func TestFragmentationWritePacket(t *testing.T) { + const ttl = 42 + + for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) source := pkt.Clone() err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, @@ -467,17 +548,101 @@ func TestFragmentation(t *testing.T) { if err != nil { t.Fatalf("r.WritePacket(_, _, _) = %s", err) } - - if got := len(ep.WrittenPackets); got != ft.expectedFrags { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, ft.expectedFrags) + if got := len(ep.WrittenPackets); got != len(ft.wantFragments) { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, len(ft.wantFragments)) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != len(ft.wantFragments) { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, len(ft.wantFragments)) } if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - compareFragments(t, ep.WrittenPackets, source, ft.mtu) + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } +} + +func TestFragmentationWritePackets(t *testing.T) { + const ttl = 42 + writePacketsTests := []struct { + description string + insertBefore int + insertAfter int + }{ + { + description: "Single packet", + insertBefore: 0, + insertAfter: 0, + }, + { + description: "With packet before", + insertBefore: 1, + insertAfter: 0, + }, + { + description: "With packet after", + insertBefore: 0, + insertAfter: 1, + }, + { + description: "With packet before and after", + insertBefore: 1, + insertAfter: 1, + }, + } + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv4MinimumSize, []int{1}, header.IPv4ProtocolNumber) + + for _, test := range writePacketsTests { + t.Run(test.description, func(t *testing.T) { + for _, ft := range fragmentationTests { + t.Run(ft.description, func(t *testing.T) { + var pkts stack.PacketBufferList + for i := 0; i < test.insertBefore; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkts.PushBack(pkt.Clone()) + for i := 0; i < test.insertAfter; i++ { + pkts.PushBack(tinyPacket.Clone()) + } + + ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) + r := buildRoute(t, ep) + + wantTotalPackets := len(ft.wantFragments) + test.insertBefore + test.insertAfter + n, err := r.WritePackets(ft.gso, pkts, stack.NetworkHeaderParams{ + Protocol: tcp.ProtocolNumber, + TTL: ttl, + TOS: stack.DefaultTOS, + }) + if err != nil { + t.Errorf("got WritePackets(_, _, _) = (_, %s), want = (_, nil)", err) + } + if n != wantTotalPackets { + t.Errorf("got WritePackets(_, _, _) = (%d, _), want = (%d, _)", n, wantTotalPackets) + } + if got := len(ep.WrittenPackets); got != wantTotalPackets { + t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.PacketsSent.Value()); got != wantTotalPackets { + t.Errorf("got c.Route.Stats().IP.PacketsSent.Value() = %d, want = %d", got, wantTotalPackets) + } + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != 0 { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) + } + + if wantTotalPackets == 0 { + return + } + + fragments := ep.WrittenPackets[test.insertBefore : len(ft.wantFragments)+test.insertBefore] + if err := compareFragments(fragments, pkt, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) + } + }) + } }) } } @@ -534,14 +699,14 @@ func TestFragmentationErrors(t *testing.T) { t.Run(ft.description, func(t *testing.T) { ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) r := buildRoute(t, ep) - pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) if err != expectedError { - t.Errorf("got WritePacket() = %s, want = %s", err, expectedError) + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) } if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) @@ -1277,6 +1442,7 @@ func TestReceiveFragments(t *testing.T) { SrcAddr: frag.srcAddr, DstAddr: frag.dstAddr, }) + ip.SetChecksum(^ip.CalculateChecksum()) vv := hdr.View().ToVectorisedView() vv.AppendView(frag.payload) @@ -1545,6 +1711,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1588,6 +1755,7 @@ func TestPacketQueing(t *testing.T) { SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, DstAddr: host1IPv4Addr.AddressWithPrefix.Address, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), })) @@ -1633,7 +1801,7 @@ func TestPacketQueing(t *testing.T) { } s.SetRouteTable([]tcpip.Route{ - tcpip.Route{ + { Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), NIC: nicID, }, diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index a454f6c34..ead6bedcb 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -252,26 +252,29 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - it, err := ns.Options().Iter(false /* check */) - if err != nil { - // Options are not valid as per the wire format, silently drop the packet. - received.Invalid.Increment() - return - } + var sourceLinkAddr tcpip.LinkAddress + { + it, err := ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the + // packet. + received.Invalid.Increment() + return + } - sourceLinkAddr, ok := getSourceLinkAddr(it) - if !ok { - received.Invalid.Increment() - return + sourceLinkAddr, ok = getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } } - unspecifiedSource := r.RemoteAddress == header.IPv6Any - // As per RFC 4861 section 4.3, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. + unspecifiedSource := r.RemoteAddress == header.IPv6Any if len(sourceLinkAddr) == 0 { if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { received.Invalid.Increment() @@ -297,41 +300,51 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - // ICMPv6 Neighbor Solicit messages are always sent to - // specially crafted IPv6 multicast addresses. As a result, the - // route we end up with here has as its LocalAddress such a - // multicast address. It would be nonsense to claim that our - // source address is a multicast address, so we manually set - // the source address to the target address requested in the - // solicit message. Since that requires mutating the route, we - // must first clone it. - r := r.Clone() - defer r.Release() - r.LocalAddress = targetAddr - - // As per RFC 4861 section 7.2.4, if the the source of the solicitation is - // the unspecified address, the node MUST set the Solicited flag to zero and - // multicast the advertisement to the all-nodes address. - solicited := true + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST [...] and multicast the advertisement to the all-nodes address. + // + remoteAddr := r.RemoteAddress if unspecifiedSource { - solicited = false - r.RemoteAddress = header.IPv6AllNodesMulticastAddress + remoteAddr = header.IPv6AllNodesMulticastAddress + } + + // Even if we were able to receive a packet from some remote, we may not + // have a route to it - the remote may be blocked via routing rules. We must + // always consult our routing table and find a route to the remote before + // sending any packet. + r, err := e.protocol.stack.FindRoute(e.nic.ID(), targetAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + // If we cannot find a route to the destination, silently drop the packet. + return } + defer r.Release() - // If the NS has a source link-layer option, use the link address it - // specifies as the remote link address for the response instead of the - // source link address of the packet. + // If the NS has a source link-layer option, resolve the route immediately + // to avoid querying the neighbor table when the neighbor entry was updated + // as probing the neighbor table for a link address will transition the + // entry's state from stale to delay. + // + // Note, if the source link address is unspecified and this is a unicast + // solicitation, we may need to perform neighbor discovery to send the + // neighbor advertisement response. This is expected as per RFC 4861 section + // 7.2.4: + // + // Because unicast Neighbor Solicitations are not required to include a + // Source Link-Layer Address, it is possible that a node sending a + // solicited Neighbor Advertisement does not have a corresponding link- + // layer address for its neighbor in its Neighbor Cache. In such + // situations, a node will first have to use Neighbor Discovery to + // determine the link-layer address of its neighbor (i.e., send out a + // multicast Neighbor Solicitation). // - // TODO(#2401): As per RFC 4861 section 7.2.4 we should consult our link - // address cache for the right destination link address instead of manually - // patching the route with the remote link address if one is specified in a - // Source Link-Layer Address option. if len(sourceLinkAddr) != 0 { - r.RemoteLinkAddress = sourceLinkAddr + r.ResolveWith(sourceLinkAddr) } optsSerializer := header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(r.LocalLinkAddress), + header.NDPTargetLinkLayerAddressOption(e.nic.LinkAddress()), } neighborAdvertSize := header.ICMPv6NeighborAdvertMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -341,7 +354,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme packet := header.ICMPv6(pkt.TransportHeader().Push(neighborAdvertSize)) packet.SetType(header.ICMPv6NeighborAdvert) na := header.NDPNeighborAdvert(packet.NDPPayload()) - na.SetSolicitedFlag(solicited) + + // As per RFC 4861 section 7.2.4: + // + // If the source of the solicitation is the unspecified address, the node + // MUST set the Solicited flag to zero and [..]. Otherwise, the node MUST + // set the Solicited flag to one and [..]. + // + na.SetSolicitedFlag(!unspecifiedSource) na.SetOverrideFlag(true) na.SetTargetAddress(targetAddr) na.Options().Serialize(optsSerializer) @@ -419,19 +439,19 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - if len(targetLinkAddr) != 0 { - if e.nud == nil { + if e.nud == nil { + if len(targetLinkAddr) != 0 { e.linkAddrCache.AddLinkAddress(e.nic.ID(), targetAddr, targetLinkAddr) - return } - - e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ - Solicited: na.SolicitedFlag(), - Override: na.OverrideFlag(), - IsRouter: na.RouterFlag(), - }) + return } + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) + case header.ICMPv6EchoRequest: received.EchoRequest.Increment() icmpHdr, ok := pkt.TransportHeader().Consume(header.ICMPv6EchoMinimumSize) @@ -635,6 +655,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAdd r := stack.Route{ LocalAddress: localAddr, RemoteAddress: addr, + LocalLinkAddress: linkEP.LinkAddress(), RemoteLinkAddress: remoteLinkAddr, } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 3affcc4e4..8dc33c560 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -101,14 +101,19 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } -type stubNUDHandler struct{} +type stubNUDHandler struct { + probeCount int + confirmationCount int +} var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { + s.probeCount++ } -func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { + s.confirmationCount++ } func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { @@ -118,6 +123,12 @@ var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { stack.NetworkLinkEndpoint + + linkAddr tcpip.LinkAddress +} + +func (i *testInterface) LinkAddress() tcpip.LinkAddress { + return i.linkAddr } func (*testInterface) ID() tcpip.NICID { @@ -1492,3 +1503,240 @@ func TestPacketQueing(t *testing.T) { }) } } + +func TestCallsToNeighborCache(t *testing.T) { + tests := []struct { + name string + createPacket func() header.ICMPv6 + multicast bool + source tcpip.Address + destination tcpip.Address + wantProbeCount int + wantConfirmationCount int + }{ + { + name: "Unicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "The source link-layer address option SHOULD be included in unicast + // solicitations." - RFC 4861 section 4.3 + // + // A Neighbor Advertisement needs to be sent in response, but the + // Neighbor Cache shouldn't be updated since we have no useful + // information about the sender. + wantProbeCount: 0, + }, + { + name: "Unicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantProbeCount: 1, + }, + { + name: "Multicast Neighbor Solicitation without source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + // "The source link-layer address option MUST be included in multicast + // solicitations." - RFC 4861 section 4.3 + wantProbeCount: 0, + }, + { + name: "Multicast Neighbor Solicitation with source link-layer address option", + createPacket: func() header.ICMPv6 { + nsSize := header.ICMPv6NeighborSolicitMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(nsSize)) + icmp.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(icmp.NDPPayload()) + ns.SetTargetAddress(lladdr0) + ns.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.SolicitedNodeAddr(lladdr0), + wantProbeCount: 1, + }, + { + name: "Unicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: lladdr0, + // "When responding to unicast solicitations, the target link-layer + // address option can be omitted since the sender of the solicitation has + // the correct link-layer address; otherwise, it would not be able to + // send the unicast solicitation in the first place." + // - RFC 4861 section 4.4 + wantConfirmationCount: 1, + }, + { + name: "Unicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: lladdr0, + wantConfirmationCount: 1, + }, + { + name: "Multicast Neighbor Advertisement without target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + // "Target link-layer address MUST be included for multicast solicitations + // in order to avoid infinite Neighbor Solicitation "recursion" when the + // peer node does not have a cache entry to return a Neighbor + // Advertisements message." - RFC 4861 section 4.4 + wantConfirmationCount: 0, + }, + { + name: "Multicast Neighbor Advertisement with target link-layer address option", + createPacket: func() header.ICMPv6 { + naSize := header.ICMPv6NeighborAdvertMinimumSize + header.NDPLinkLayerAddressSize + icmp := header.ICMPv6(buffer.NewView(naSize)) + icmp.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(icmp.NDPPayload()) + na.SetSolicitedFlag(false) + na.SetOverrideFlag(false) + na.SetTargetAddress(lladdr1) + na.Options().Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + return icmp + }, + source: lladdr1, + destination: header.IPv6AllNodesMulticastAddress, + wantConfirmationCount: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6}, + UseNeighborCache: true, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + nudHandler := &stubNUDHandler{} + ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + defer ep.Close() + + if err := ep.Enable(); err != nil { + t.Fatalf("ep.Enable(): %s", err) + } + + r, err := s.FindRoute(nicID, lladdr0, test.source, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + // TODO(gvisor.dev/issue/4517): Remove the need for this manual patch. + r.LocalAddress = test.destination + + icmp := test.createPacket() + icmp.SetChecksum(header.ICMPv6Checksum(icmp, r.RemoteAddress, r.LocalAddress, buffer.VectorisedView{})) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.RemoteAddress, + DstAddr: r.LocalAddress, + }) + ep.HandlePacket(&r, pkt) + + // Confirm the endpoint calls the correct NUDHandler method. + if nudHandler.probeCount != test.wantProbeCount { + t.Errorf("got nudHandler.probeCount = %d, want = %d", nudHandler.probeCount, test.wantProbeCount) + } + if nudHandler.confirmationCount != test.wantConfirmationCount { + t.Errorf("got nudHandler.confirmationCount = %d, want = %d", nudHandler.confirmationCount, test.wantConfirmationCount) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 2bd8f4ece..9670696c7 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -387,7 +387,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s } func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return pkt.Size() > int(e.nic.MTU()) && (gso == nil || gso.Type == stack.GSONone) + return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) } // handleFragments fragments pkt and calls the handler function on each @@ -416,17 +416,18 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p } n++ if !more { - break + return n, pf.RemainingFragmentCount(), nil } } - - return n, 0, nil } // WritePacket writes a packet to the given destination address and protocol. func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + return e.writePacket(r, gso, pkt, params.Protocol) +} +func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error { // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) @@ -468,7 +469,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw } if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -501,21 +502,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) if e.packetMustBeFragmented(pb, gso) { - current := pb - _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + // Keep track of the packet that is about to be fragmented so it can be + // removed once the fragmentation is done. + originalPkt := pb + if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. - pkts.InsertAfter(current, fragPkt) - current = current.Next() + pkts.InsertAfter(pb, fragPkt) + pb = fragPkt return nil - }) - if err != nil { + }); err != nil { r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) return 0, err } - // The fragmented packet can be released. The rest of the packets can be - // processed. - pkts.Remove(pb) - pb = current + // Remove the packet that was just fragmented and process the rest. + pkts.Remove(originalPkt) } } @@ -569,11 +569,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet -// supported by IPv6. -func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { - // TODO(b/146666412): Support IPv6 header-included packets. - return tcpip.ErrNotSupported +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { + // The packet already has an IP header, but there are a few required checks. + h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return tcpip.ErrMalformedHeader + } + ip := header.IPv6(h) + + // Always set the payload length. + pktSize := pkt.Data.Size() + ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize)) + + // Set the source address when zero. + if ip.SourceAddress() == header.IPv6Any { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, it will + // be part of the route anyways. + ip.SetDestinationAddress(r.RemoteAddress) + + // Populate the packet buffer's network header and don't allow an invalid + // packet to be sent. + // + // Note that parsing only makes sure that the packet is well formed as per the + // wire format. We also want to check if the header's fields are valid before + // sending the packet. + proto, _, _, _, ok := parse.IPv6(pkt) + if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) { + return tcpip.ErrMalformedHeader + } + + return e.writePacket(r, nil /* gso */, pkt, proto) } // HandlePacket is called by the link layer when new ipv6 packets arrive for diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index bee18d1a8..297868f24 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -49,6 +49,8 @@ const ( fragmentExtHdrID = uint8(header.IPv6FragmentExtHdrIdentifier) destinationExtHdrID = uint8(header.IPv6DestinationOptionsExtHdrIdentifier) noNextHdrID = uint8(header.IPv6NoNextHeaderIdentifier) + + extraHeaderReserve = 50 ) // testReceiveICMP tests receiving an ICMP packet from src to dst. want is the @@ -181,6 +183,9 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB return fmt.Errorf("fragment #%d: fragmentIPHeader mismatch (-want +got):\n%s", i, diff) } + if got := fragment.AvailableHeaderBytes(); got != extraHeaderReserve { + return fmt.Errorf("fragment #%d: got packet.AvailableHeaderBytes() = %d, want = %d", i, got, extraHeaderReserve) + } if fragment.NetworkProtocolNumber != sourcePacket.NetworkProtocolNumber { return fmt.Errorf("fragment #%d: got fragment.NetworkProtocolNumber = %d, want = %d", i, fragment.NetworkProtocolNumber, sourcePacket.NetworkProtocolNumber) } @@ -208,8 +213,7 @@ func compareFragments(packets []*stack.PacketBuffer, sourcePacket *stack.PacketB reassembledPayload.Append(fragment.Data) } - result := reassembledPayload.ToView() - if diff := cmp.Diff(result, buffer.View(source[sourceIPHeadersLen:])); diff != "" { + if diff := cmp.Diff(buffer.View(source[sourceIPHeadersLen:]), reassembledPayload.ToView()); diff != "" { return fmt.Errorf("reassembledPayload mismatch (-want +got):\n%s", diff) } @@ -2217,24 +2221,19 @@ type fragmentInfo struct { payloadSize uint16 } -type fragmentationTestCase struct { +var fragmentationTests = []struct { description string mtu uint32 gso *stack.GSO transHdrLen int - extraHdrLen int payloadSize int wantFragments []fragmentInfo - expectedFrags int -} - -var fragmentationTests = []fragmentationTestCase{ +}{ { description: "No Fragmentation", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1000, more: false}, @@ -2243,9 +2242,8 @@ var fragmentationTests = []fragmentationTestCase{ { description: "Fragmented", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 2000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2255,20 +2253,18 @@ var fragmentationTests = []fragmentationTestCase{ { description: "No fragmentation with big header", mtu: 2000, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1000, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1100, more: false}, }, }, { - description: "Fragmented with gso nil", + description: "Fragmented with gso none", mtu: 1280, - gso: nil, + gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1400, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, @@ -2278,30 +2274,17 @@ var fragmentationTests = []fragmentationTestCase{ { description: "Fragmented with big header", mtu: 1280, - gso: &stack.GSO{}, + gso: nil, transHdrLen: 100, - extraHdrLen: header.IPv6MinimumSize, payloadSize: 1200, wantFragments: []fragmentInfo{ {offset: 0, payloadSize: 1240, more: true}, {offset: 154, payloadSize: 76, more: false}, }, }, - { - description: "Fragmented with big header and prependable bytes", - mtu: 1280, - gso: &stack.GSO{}, - transHdrLen: 20, - extraHdrLen: header.IPv6MinimumSize + 66, - payloadSize: 1500, - wantFragments: []fragmentInfo{ - {offset: 0, payloadSize: 1240, more: true}, - {offset: 154, payloadSize: 296, more: false}, - }, - }, } -func TestFragmentation(t *testing.T) { +func TestFragmentationWritePacket(t *testing.T) { const ( ttl = 42 tos = stack.DefaultTOS @@ -2310,7 +2293,7 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragmentationTests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt.Clone() ep := testutil.NewMockLinkEndpoint(ft.mtu, nil, math.MaxInt32) r := buildRoute(t, ep) @@ -2331,10 +2314,8 @@ func TestFragmentation(t *testing.T) { if got := r.Stats().IP.OutgoingPacketErrors.Value(); got != 0 { t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = 0", got) } - if len(ep.WrittenPackets) > 0 { - if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { - t.Error(err) - } + if err := compareFragments(ep.WrittenPackets, source, ft.mtu, ft.wantFragments, tcp.ProtocolNumber); err != nil { + t.Error(err) } }) } @@ -2368,7 +2349,7 @@ func TestFragmentationWritePackets(t *testing.T) { insertAfter: 1, }, } - tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) + tinyPacket := testutil.MakeRandPkt(header.TCPMinimumSize, extraHeaderReserve+header.IPv6MinimumSize, []int{1}, header.IPv6ProtocolNumber) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -2378,7 +2359,7 @@ func TestFragmentationWritePackets(t *testing.T) { for i := 0; i < test.insertBefore; i++ { pkts.PushBack(tinyPacket.Clone()) } - pkt := testutil.MakeRandPkt(ft.transHdrLen, ft.extraHdrLen, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) source := pkt pkts.PushBack(pkt.Clone()) for i := 0; i < test.insertAfter; i++ { @@ -2480,7 +2461,7 @@ func TestFragmentationErrors(t *testing.T) { for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - pkt := testutil.MakeRandPkt(ft.transHdrLen, header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) + pkt := testutil.MakeRandPkt(ft.transHdrLen, extraHeaderReserve+header.IPv6MinimumSize, []int{ft.payloadSize}, header.IPv6ProtocolNumber) ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 9033a9ed5..ac20f217e 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "context" "strings" "testing" "time" @@ -398,16 +399,17 @@ func TestNeighorSolicitationResponse(t *testing.T) { } tests := []struct { - name string - nsOpts header.NDPOptionsSerializer - nsSrcLinkAddr tcpip.LinkAddress - nsSrc tcpip.Address - nsDst tcpip.Address - nsInvalid bool - naDstLinkAddr tcpip.LinkAddress - naSolicited bool - naSrc tcpip.Address - naDst tcpip.Address + name string + nsOpts header.NDPOptionsSerializer + nsSrcLinkAddr tcpip.LinkAddress + nsSrc tcpip.Address + nsDst tcpip.Address + nsInvalid bool + naDstLinkAddr tcpip.LinkAddress + naSolicited bool + naSrc tcpip.Address + naDst tcpip.Address + performsLinkResolution bool }{ { name: "Unspecified source to solicited-node multicast destination", @@ -416,7 +418,7 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsSrc: header.IPv6Any, nsDst: nicAddrSNMC, nsInvalid: false, - naDstLinkAddr: remoteLinkAddr0, + naDstLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), naSolicited: false, naSrc: nicAddr, naDst: header.IPv6AllNodesMulticastAddress, @@ -449,7 +451,6 @@ func TestNeighorSolicitationResponse(t *testing.T) { nsDst: nicAddr, nsInvalid: true, }, - { name: "Specified source with 1 source ll to multicast destination", nsOpts: header.NDPOptionsSerializer{ @@ -509,6 +510,10 @@ func TestNeighorSolicitationResponse(t *testing.T) { naSolicited: true, naSrc: nicAddr, naDst: remoteAddr, + // Since we send a unicast solicitations to a node without an entry for + // the remote, the node needs to perform neighbor discovery to get the + // remote's link address to send the advertisement response. + performsLinkResolution: true, }, { name: "Specified source with 1 source ll to unicast destination", @@ -615,11 +620,78 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } - p, got := e.Read() + if test.performsLinkResolution { + p, got := e.ReadContext(context.Background()) + if !got { + t.Fatal("expected an NDP NS response") + } + + if p.Route.LocalAddress != nicAddr { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, nicAddr) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + respNSDst := header.SolicitedNodeAddr(test.nsSrc) + if p.Route.RemoteAddress != respNSDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) + } + if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + } + + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(nicAddr), + checker.DstAddr(respNSDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNS( + checker.NDPNSTargetAddress(test.nsSrc), + checker.NDPNSOptions([]header.NDPOption{ + header.NDPSourceLinkLayerAddressOption(nicLinkAddr), + }), + )) + + ser := header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + } + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + ser.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + na := header.NDPNeighborAdvert(pkt.NDPPayload()) + na.SetSolicitedFlag(true) + na.SetOverrideFlag(true) + na.SetTargetAddress(test.nsSrc) + na.Options().Serialize(ser) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, nicAddr, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: test.nsSrc, + DstAddr: nicAddr, + }) + e.InjectLinkAddr(ProtocolNumber, "", stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) + } + + p, got := e.ReadContext(context.Background()) if !got { t.Fatal("expected an NDP NA response") } + if p.Route.LocalAddress != test.naSrc { + t.Errorf("got p.Route.LocalAddress = %s, want = %s", p.Route.LocalAddress, test.naSrc) + } + if p.Route.LocalLinkAddress != nicLinkAddr { + t.Errorf("p.Route.LocalLinkAddress = %s, want = %s", p.Route.LocalLinkAddress, nicLinkAddr) + } + if p.Route.RemoteAddress != test.naDst { + t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) + } if p.Route.RemoteLinkAddress != test.naDstLinkAddr { t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) } diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index eba97334e..d09ebe7fa 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -123,6 +123,7 @@ go_test( "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/ports", diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 4d69a4de1..be61a21af 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -406,9 +406,9 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla // INCOMPLETE state." - RFC 4861 section 7.2.5 case Reachable, Stale, Delay, Probe: - sameLinkAddr := e.neigh.LinkAddr == linkAddr + isLinkAddrDifferent := len(linkAddr) != 0 && e.neigh.LinkAddr != linkAddr - if !sameLinkAddr { + if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { e.dispatchChangeEventLocked(Stale) @@ -431,7 +431,7 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla } } - if flags.Solicited && (flags.Override || sameLinkAddr) { + if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { if e.neigh.State != Reachable { e.dispatchChangeEventLocked(Reachable) } diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index e79abebca..3ee2a3b31 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -83,15 +83,18 @@ func eventDiffOptsWithSort() []cmp.Option { // | Reachable | Stale | Reachable timer expired | | Changed | // | Reachable | Stale | Probe or confirmation w/ different address | | Changed | // | Stale | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Stale | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Stale | Stale | Override confirmation | Update LinkAddr | Changed | // | Stale | Stale | Probe w/ different address | Update LinkAddr | Changed | // | Stale | Delay | Packet sent | | Changed | // | Delay | Reachable | Upper-layer confirmation | | Changed | // | Delay | Reachable | Solicited override confirmation | Update LinkAddr | Changed | +// | Delay | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Delay | Stale | Probe or confirmation w/ different address | | Changed | // | Delay | Probe | Delay timer expired | Send probe | Changed | // | Probe | Reachable | Solicited override confirmation | Update LinkAddr | Changed | // | Probe | Reachable | Solicited confirmation w/ same address | Notify wakers | Changed | +// | Probe | Reachable | Solicited confirmation w/o address | Notify wakers | Changed | // | Probe | Stale | Probe or confirmation w/ different address | | Changed | // | Probe | Probe | Retransmit timer expired | Send probe | Changed | // | Probe | Failed | Max probes sent without reply | Notify wakers | Removed | @@ -1370,6 +1373,77 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + e, nudDisp, linkRes, _ := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -1752,6 +1826,100 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { nudDisp.mu.Unlock() } +func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + c.MaxMulticastProbes = 1 + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() e, nudDisp, linkRes, _ := entryTestSetup(c) @@ -2665,6 +2833,115 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin nudDisp.mu.Unlock() } +func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Stale to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.Advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // The first probe is caused by the Unknown-to-Incomplete transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.handleConfirmationLocked("" /* linkAddr */, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + + clock.Advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 8828cc5fe..dcd4319bf 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -686,7 +685,9 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // packet to forward. fwdPkt := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: int(n.LinkEndpoint.MaxHeaderLength()), - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + // We need to do a deep copy of the IP packet because WritePacket (and + // friends) take ownership of the packet buffer, but we do not own it. + Data: PayloadSince(pkt.NetworkHeader()).ToVectorisedView(), }) // TODO(b/143425874) Decrease the TTL field in forwarded packets. diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 105583c49..7f54a6de8 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -311,11 +311,25 @@ func (h PacketHeader) Consume(size int) (v buffer.View, consumed bool) { } // PayloadSince returns packet payload starting from and including a particular -// header. This method isn't optimized and should be used in test only. +// header. +// +// The returned View is owned by the caller - its backing buffer is separate +// from the packet header's underlying packet buffer. func PayloadSince(h PacketHeader) buffer.View { - var v buffer.View + size := h.pk.Data.Size() + for _, hinfo := range h.pk.headers[h.typ:] { + size += len(hinfo.buf) + } + + v := make(buffer.View, 0, size) + for _, hinfo := range h.pk.headers[h.typ:] { v = append(v, hinfo.buf...) } - return append(v, h.pk.Data.ToView()...) + + for _, view := range h.pk.Data.Views() { + v = append(v, view...) + } + + return v } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 25f80c1f8..b76e2d37b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -126,6 +126,12 @@ func (r *Route) GSOMaxSize() uint32 { return 0 } +// ResolveWith immediately resolves a route with the specified remote link +// address. +func (r *Route) ResolveWith(addr tcpip.LinkAddress) { + r.RemoteLinkAddress = addr +} + // Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in // case address resolution requires blocking, e.g. wait for ARP reply. Waker is // notified when address resolution is complete (success or not). diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 38994cca1..e75f58c64 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -34,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -3498,6 +3499,52 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } } +func TestResolveWith(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol}, + }) + ep := channel.New(0, defaultMTU, "") + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + addr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address([]byte{192, 168, 1, 58}), + PrefixLen: 24, + }, + } + if err := s.AddProtocolAddress(nicID, addr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, addr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: header.IPv4EmptySubnet, NIC: nicID}}) + + remoteAddr := tcpip.Address([]byte{192, 168, 1, 59}) + r, err := s.FindRoute(unspecifiedNICID, "" /* localAddr */, remoteAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, '', %s, %d): %s", unspecifiedNICID, remoteAddr, header.IPv4ProtocolNumber, err) + } + defer r.Release() + + // Should initially require resolution. + if !r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = false, want = true") + } + + // Manually resolving the route should no longer require resolution. + r.ResolveWith("\x01") + if r.IsResolutionRequired() { + t.Fatal("got r.IsResolutionRequired() = true, want = false") + } +} + // TestRouteReleaseAfterAddrRemoval tests that releasing a Route after its // associated address is removed should not cause a panic. func TestRouteReleaseAfterAddrRemoval(t *testing.T) { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c42bb0991..d77848d61 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -111,6 +111,7 @@ var ( ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} + ErrMalformedHeader = &Error{msg: "header is malformed"} ) var messageToError map[string]*Error @@ -159,6 +160,7 @@ func StringToError(s string) *Error { ErrBroadcastDisabled, ErrNotPermitted, ErrAddressFamilyNotSupported, + ErrMalformedHeader, } messageToError = make(map[string]*Error) diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index a4f141253..34aab32d0 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -16,6 +16,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/ethernet", "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/pipe", "//pkg/tcpip/network/arp", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ffd38ee1a..0dcef7b04 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -178,19 +179,19 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr, stack.CapabilityResolutionRequired) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) + routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := routerStack.CreateNIC(routerNICID1, routerNIC1); err != nil { + if err := routerStack.CreateNIC(routerNICID1, ethernet.New(routerNIC1)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID1, err) } - if err := routerStack.CreateNIC(routerNICID2, routerNIC2); err != nil { + if err := routerStack.CreateNIC(routerNICID2, ethernet.New(routerNIC2)); err != nil { t.Fatalf("routerStack.CreateNIC(%d, _): %s", routerNICID2, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index bf3a6f6ee..6ddcda70c 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -126,12 +127,12 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr, stack.CapabilityResolutionRequired) + host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) - if err := host1Stack.CreateNIC(host1NICID, host1NIC); err != nil { + if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) } - if err := host2Stack.CreateNIC(host2NICID, host2NIC); err != nil { + if err := host2Stack.CreateNIC(host2NICID, ethernet.New(host2NIC)); err != nil { t.Fatalf("host2Stack.CreateNIC(%d, _): %s", host2NICID, err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 4f2ca7f54..f1028823b 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -80,6 +80,7 @@ func TestPingMulticastBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), @@ -250,6 +251,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { SrcAddr: remoteIPv4Addr, DstAddr: dst, }) + ip.SetChecksum(^ip.CalculateChecksum()) e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: hdr.View().ToVectorisedView(), diff --git a/runsc/BUILD b/runsc/BUILD index 33d8554af..3b91b984a 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -13,16 +13,7 @@ go_binary( "//visibility:public", ], x_defs = {"main.version": "{STABLE_VERSION}"}, - deps = [ - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/platform", - "//runsc/cmd", - "//runsc/config", - "//runsc/flag", - "//runsc/specutils", - "@com_github_google_subcommands//:go_default_library", - ], + deps = ["//runsc/cli"], ) # The runsc-race target is a race-compatible BUILD target. This must be built @@ -49,16 +40,7 @@ go_binary( "//visibility:public", ], x_defs = {"main.version": "{STABLE_VERSION}"}, - deps = [ - "//pkg/log", - "//pkg/refs", - "//pkg/sentry/platform", - "//runsc/cmd", - "//runsc/config", - "//runsc/flag", - "//runsc/specutils", - "@com_github_google_subcommands//:go_default_library", - ], + deps = ["//runsc/cli"], ) sh_test( diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 2d9517f4a..248f77c34 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -110,8 +110,8 @@ go_library( "//runsc/config", "//runsc/specutils", "//runsc/specutils/seccomp", - "@com_github_golang_protobuf//proto:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index 84c67cbc2..7076ae2e2 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -19,7 +19,7 @@ import ( "os" "syscall" - "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 82e459f46..004da5b40 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -264,10 +264,38 @@ func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Cre } cu.Add(func() { lower.DecRef(ctx) }) + // Propagate the lower layer's root's owner, group, and mode to the upper + // layer's root for consistency with VFS1. + upperRootVD := vfs.MakeVirtualDentry(upper, upper.Root()) + lowerRootVD := vfs.MakeVirtualDentry(lower, lower.Root()) + stat, err := c.k.VFS().StatAt(ctx, creds, &vfs.PathOperation{ + Root: lowerRootVD, + Start: lowerRootVD, + }, &vfs.StatOptions{ + Mask: linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE, + }) + if err != nil { + return nil, nil, err + } + err = c.k.VFS().SetStatAt(ctx, creds, &vfs.PathOperation{ + Root: upperRootVD, + Start: upperRootVD, + }, &vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: (linux.STATX_UID | linux.STATX_GID | linux.STATX_MODE) & stat.Mask, + UID: stat.UID, + GID: stat.GID, + Mode: stat.Mode, + }, + }) + if err != nil { + return nil, nil, err + } + // Configure overlay with both layers. overlayOpts.GetFilesystemOptions.InternalData = overlay.FilesystemOptions{ - UpperRoot: vfs.MakeVirtualDentry(upper, upper.Root()), - LowerRoots: []vfs.VirtualDentry{vfs.MakeVirtualDentry(lower, lower.Root())}, + UpperRoot: upperRootVD, + LowerRoots: []vfs.VirtualDentry{lowerRootVD}, } return &overlayOpts, cu.Release(), nil } diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index 8fbc3887a..56da21584 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -201,13 +201,15 @@ func LoadPaths(pid string) (map[string]string, error) { paths := make(map[string]string) scanner := bufio.NewScanner(f) for scanner.Scan() { - // Format: ID:controller1,controller2:path + // Format: ID:[name=]controller1,controller2:path // Example: 2:cpu,cpuacct:/user.slice tokens := strings.Split(scanner.Text(), ":") if len(tokens) != 3 { return nil, fmt.Errorf("invalid cgroups file, line: %q", scanner.Text()) } for _, ctrlr := range strings.Split(tokens[1], ",") { + // Remove prefix for cgroups with no controller, eg. systemd. + ctrlr = strings.TrimPrefix(ctrlr, "name=") paths[ctrlr] = tokens[2] } } @@ -237,7 +239,7 @@ func New(spec *specs.Spec) (*Cgroup, error) { var err error parents, err = LoadPaths("self") if err != nil { - return nil, fmt.Errorf("finding current cgroups: %v", err) + return nil, fmt.Errorf("finding current cgroups: %w", err) } } return &Cgroup{ @@ -276,10 +278,8 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { } return err } - if res != nil { - if err := cfg.ctrlr.set(res, path); err != nil { - return err - } + if err := cfg.ctrlr.set(res, path); err != nil { + return err } } clean.Release() @@ -304,14 +304,15 @@ func (c *Cgroup) Uninstall() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) - if err := backoff.Retry(func() error { + fn := func() error { err := syscall.Rmdir(path) if os.IsNotExist(err) { return nil } return err - }, b); err != nil { - return fmt.Errorf("removing cgroup path %q: %v", path, err) + } + if err := backoff.Retry(fn, b); err != nil { + return fmt.Errorf("removing cgroup path %q: %w", path, err) } } return nil @@ -332,7 +333,6 @@ func (c *Cgroup) Join() (func(), error) { if _, ok := controllers[ctrlr]; ok { fullPath := filepath.Join(cgroupRoot, ctrlr, path) undoPaths = append(undoPaths, fullPath) - break } } @@ -422,7 +422,7 @@ func (*noop) set(*specs.LinuxResources, string) error { type memory struct{} func (*memory) set(spec *specs.LinuxResources, path string) error { - if spec.Memory == nil { + if spec == nil || spec.Memory == nil { return nil } if err := setOptionalValueInt(path, "memory.limit_in_bytes", spec.Memory.Limit); err != nil { @@ -455,7 +455,7 @@ func (*memory) set(spec *specs.LinuxResources, path string) error { type cpu struct{} func (*cpu) set(spec *specs.LinuxResources, path string) error { - if spec.CPU == nil { + if spec == nil || spec.CPU == nil { return nil } if err := setOptionalValueUint(path, "cpu.shares", spec.CPU.Shares); err != nil { @@ -478,7 +478,7 @@ type cpuSet struct{} func (*cpuSet) set(spec *specs.LinuxResources, path string) error { // cpuset.cpus and mems are required fields, but are not set on a new cgroup. // If not set in the spec, get it from one of the ancestors cgroup. - if spec.CPU == nil || spec.CPU.Cpus == "" { + if spec == nil || spec.CPU == nil || spec.CPU.Cpus == "" { if _, err := fillFromAncestor(filepath.Join(path, "cpuset.cpus")); err != nil { return err } @@ -488,18 +488,17 @@ func (*cpuSet) set(spec *specs.LinuxResources, path string) error { } } - if spec.CPU == nil || spec.CPU.Mems == "" { + if spec == nil || spec.CPU == nil || spec.CPU.Mems == "" { _, err := fillFromAncestor(filepath.Join(path, "cpuset.mems")) return err } - mems := spec.CPU.Mems - return setValue(path, "cpuset.mems", mems) + return setValue(path, "cpuset.mems", spec.CPU.Mems) } type blockIO struct{} func (*blockIO) set(spec *specs.LinuxResources, path string) error { - if spec.BlockIO == nil { + if spec == nil || spec.BlockIO == nil { return nil } @@ -549,7 +548,7 @@ func setThrottle(path, name string, devs []specs.LinuxThrottleDevice) error { type networkClass struct{} func (*networkClass) set(spec *specs.LinuxResources, path string) error { - if spec.Network == nil { + if spec == nil || spec.Network == nil { return nil } return setOptionalValueUint32(path, "net_cls.classid", spec.Network.ClassID) @@ -558,7 +557,7 @@ func (*networkClass) set(spec *specs.LinuxResources, path string) error { type networkPrio struct{} func (*networkPrio) set(spec *specs.LinuxResources, path string) error { - if spec.Network == nil { + if spec == nil || spec.Network == nil { return nil } for _, prio := range spec.Network.Priorities { @@ -573,7 +572,7 @@ func (*networkPrio) set(spec *specs.LinuxResources, path string) error { type pids struct{} func (*pids) set(spec *specs.LinuxResources, path string) error { - if spec.Pids == nil || spec.Pids.Limit <= 0 { + if spec == nil || spec.Pids == nil || spec.Pids.Limit <= 0 { return nil } val := strconv.FormatInt(spec.Pids.Limit, 10) @@ -583,6 +582,9 @@ func (*pids) set(spec *specs.LinuxResources, path string) error { type hugeTLB struct{} func (*hugeTLB) set(spec *specs.LinuxResources, path string) error { + if spec == nil { + return nil + } for _, limit := range spec.HugepageLimits { name := fmt.Sprintf("hugetlb.%s.limit_in_bytes", limit.Pagesize) val := strconv.FormatUint(limit.Limit, 10) diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD new file mode 100644 index 000000000..32cce2a18 --- /dev/null +++ b/runsc/cli/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "cli", + srcs = ["main.go"], + visibility = [ + "//:__pkg__", + "//runsc:__pkg__", + ], + deps = [ + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/platform", + "//runsc/cmd", + "//runsc/config", + "//runsc/flag", + "//runsc/specutils", + "@com_github_google_subcommands//:go_default_library", + ], +) diff --git a/runsc/cli/main.go b/runsc/cli/main.go new file mode 100644 index 000000000..bca015db5 --- /dev/null +++ b/runsc/cli/main.go @@ -0,0 +1,256 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli is the main entrypoint for runsc. +package cli + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "os" + "os/signal" + "syscall" + "time" + + "github.com/google/subcommands" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/runsc/cmd" + "gvisor.dev/gvisor/runsc/config" + "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/specutils" +) + +var ( + // Although these flags are not part of the OCI spec, they are used by + // Docker, and thus should not be changed. + // TODO(gvisor.dev/issue/193): support systemd cgroups + systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") + showVersion = flag.Bool("version", false, "show version and exit.") + + // These flags are unique to runsc, and are used to configure parts of the + // system that are not covered by the runtime spec. + + // Debugging flags. + logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") + debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") +) + +// Main is the main entrypoint. +func Main(version string) { + // Help and flags commands are generated automatically. + help := cmd.NewHelp(subcommands.DefaultCommander) + help.Register(new(cmd.Syscalls)) + subcommands.Register(help, "") + subcommands.Register(subcommands.FlagsCommand(), "") + + // Installation helpers. + const helperGroup = "helpers" + subcommands.Register(new(cmd.Install), helperGroup) + subcommands.Register(new(cmd.Uninstall), helperGroup) + + // Register user-facing runsc commands. + subcommands.Register(new(cmd.Checkpoint), "") + subcommands.Register(new(cmd.Create), "") + subcommands.Register(new(cmd.Delete), "") + subcommands.Register(new(cmd.Do), "") + subcommands.Register(new(cmd.Events), "") + subcommands.Register(new(cmd.Exec), "") + subcommands.Register(new(cmd.Gofer), "") + subcommands.Register(new(cmd.Kill), "") + subcommands.Register(new(cmd.List), "") + subcommands.Register(new(cmd.Pause), "") + subcommands.Register(new(cmd.PS), "") + subcommands.Register(new(cmd.Restore), "") + subcommands.Register(new(cmd.Resume), "") + subcommands.Register(new(cmd.Run), "") + subcommands.Register(new(cmd.Spec), "") + subcommands.Register(new(cmd.State), "") + subcommands.Register(new(cmd.Start), "") + subcommands.Register(new(cmd.Wait), "") + + // Register internal commands with the internal group name. This causes + // them to be sorted below the user-facing commands with empty group. + // The string below will be printed above the commands. + const internalGroup = "internal use only" + subcommands.Register(new(cmd.Boot), internalGroup) + subcommands.Register(new(cmd.Debug), internalGroup) + subcommands.Register(new(cmd.Gofer), internalGroup) + subcommands.Register(new(cmd.Statefile), internalGroup) + + config.RegisterFlags() + + // All subcommands must be registered before flag parsing. + flag.Parse() + + // Are we showing the version? + if *showVersion { + // The format here is the same as runc. + fmt.Fprintf(os.Stdout, "runsc version %s\n", version) + fmt.Fprintf(os.Stdout, "spec: %s\n", specutils.Version) + os.Exit(0) + } + + // Create a new Config from the flags. + conf, err := config.NewFromFlags() + if err != nil { + cmd.Fatalf(err.Error()) + } + + // TODO(gvisor.dev/issue/193): support systemd cgroups + if *systemdCgroup { + fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") + os.Exit(1) + } + + var errorLogger io.Writer + if *logFD > -1 { + errorLogger = os.NewFile(uintptr(*logFD), "error log file") + + } else if conf.LogFilename != "" { + // We must set O_APPEND and not O_TRUNC because Docker passes + // the same log file for all commands (and also parses these + // log files), so we can't destroy them on each command. + var err error + errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err) + } + } + cmd.ErrorLogger = errorLogger + + if _, err := platform.Lookup(conf.Platform); err != nil { + cmd.Fatalf("%v", err) + } + + // Sets the reference leak check mode. Also set it in config below to + // propagate it to child processes. + refs.SetLeakMode(conf.ReferenceLeak) + + // Set up logging. + if conf.Debug { + log.SetLevel(log.Debug) + } + + // Logging will include the local date and time via the time package. + // + // On first use, time.Local initializes the local time zone, which + // involves opening tzdata files on the host. Since this requires + // opening host files, it must be done before syscall filter + // installation. + // + // Generally there will be a log message before filter installation + // that will force initialization, but force initialization here in + // case that does not occur. + _ = time.Local.String() + + subcommand := flag.CommandLine.Arg(0) + + var e log.Emitter + if *debugLogFD > -1 { + f := os.NewFile(uintptr(*debugLogFD), "debug log file") + + e = newEmitter(conf.DebugLogFormat, f) + + } else if conf.DebugLog != "" { + f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */) + if err != nil { + cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err) + } + e = newEmitter(conf.DebugLogFormat, f) + + } else { + // Stderr is reserved for the application, just discard the logs if no debug + // log is specified. + e = newEmitter("text", ioutil.Discard) + } + + if *panicLogFD > -1 || *debugLogFD > -1 { + fd := *panicLogFD + if fd < 0 { + fd = *debugLogFD + } + // Quick sanity check to make sure no other commands get passed + // a log fd (they should use log dir instead). + if subcommand != "boot" && subcommand != "gofer" { + cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand) + } + + // If we are the boot process, then we own our stdio FDs and can do what we + // want with them. Since Docker and Containerd both eat boot's stderr, we + // dup our stderr to the provided log FD so that panics will appear in the + // logs, rather than just disappear. + if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { + cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) + } + } else if conf.AlsoLogToStderr { + e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} + } + + log.SetTarget(e) + + log.Infof("***************************") + log.Infof("Args: %s", os.Args) + log.Infof("Version %s", version) + log.Infof("PID: %d", os.Getpid()) + log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid()) + log.Infof("Configuration:") + log.Infof("\t\tRootDir: %s", conf.RootDir) + log.Infof("\t\tPlatform: %v", conf.Platform) + log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay) + log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets) + log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls) + log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) + log.Infof("***************************") + + if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { + // SIGTERM is sent to all processes if a test exceeds its + // timeout and this case is handled by syscall_test_runner. + log.Warningf("Block the TERM signal. This is only safe in tests!") + signal.Ignore(syscall.SIGTERM) + } + + // Call the subcommand and pass in the configuration. + var ws syscall.WaitStatus + subcmdCode := subcommands.Execute(context.Background(), conf, &ws) + if subcmdCode == subcommands.ExitSuccess { + log.Infof("Exiting with status: %v", ws) + if ws.Signaled() { + // No good way to return it, emulate what the shell does. Maybe raise + // signal to self? + os.Exit(128 + int(ws.Signal())) + } + os.Exit(ws.ExitStatus()) + } + // Return an error that is unlikely to be used by the application. + log.Warningf("Failure to execute command, err: %v", subcmdCode) + os.Exit(128) +} + +func newEmitter(format string, logFile io.Writer) log.Emitter { + switch format { + case "text": + return log.GoogleEmitter{&log.Writer{Next: logFile}} + case "json": + return log.JSONEmitter{&log.Writer{Next: logFile}} + case "json-k8s": + return log.K8sJSONEmitter{&log.Writer{Next: logFile}} + } + cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) + panic("unreachable") +} diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index d1f2e9e6d..640de4c47 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -17,6 +17,7 @@ package cmd import ( "context" "encoding/json" + "errors" "fmt" "io/ioutil" "math/rand" @@ -36,6 +37,8 @@ import ( "gvisor.dev/gvisor/runsc/specutils" ) +var errNoDefaultInterface = errors.New("no default interface found") + // Do implements subcommands.Command for the "do" command. It sets up a simple // sandbox and executes the command inside it. See Usage() for more details. type Do struct { @@ -126,26 +129,28 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) if conf.Network == config.NetworkNone { - netns := specs.LinuxNamespace{ - Type: specs.NetworkNamespace, - } - if spec.Linux != nil { - panic("spec.Linux is not nil") - } - spec.Linux = &specs.Linux{Namespaces: []specs.LinuxNamespace{netns}} + addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace}) } else if conf.Rootless { if conf.Network == config.NetworkSandbox { - c.notifyUser("*** Warning: using host network due to --rootless ***") + c.notifyUser("*** Warning: sandbox network isn't supported with --rootless, switching to host ***") conf.Network = config.NetworkHost } } else { - clean, err := c.setupNet(cid, spec) - if err != nil { + switch clean, err := c.setupNet(cid, spec); err { + case errNoDefaultInterface: + log.Warningf("Network interface not found, using internal network") + addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace}) + conf.Network = config.NetworkHost + + case nil: + // Setup successfull. + defer clean() + + default: return Errorf("Error setting up network: %v", err) } - defer clean() } out, err := json.Marshal(spec) @@ -199,6 +204,13 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su return subcommands.ExitSuccess } +func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) { + if spec.Linux == nil { + spec.Linux = &specs.Linux{} + } + spec.Linux.Namespaces = append(spec.Linux.Namespaces, ns) +} + func (c *Do) notifyUser(format string, v ...interface{}) { if !c.quiet { fmt.Printf(format+"\n", v...) @@ -219,10 +231,14 @@ func resolvePath(path string) (string, error) { return path, nil } +// setupNet setups up the sandbox network, including the creation of a network +// namespace, and iptable rules to redirect the traffic. Returns a cleanup +// function to tear down the network. Returns errNoDefaultInterface when there +// is no network interface available to setup the network. func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) { dev, err := defaultDevice() if err != nil { - return nil, err + return nil, errNoDefaultInterface } peerIP, err := calculatePeerIP(c.ip) if err != nil { @@ -279,14 +295,11 @@ func (c *Do) setupNet(cid string, spec *specs.Spec) (func(), error) { return nil, err } - if spec.Linux == nil { - spec.Linux = &specs.Linux{} - } netns := specs.LinuxNamespace{ Type: specs.NetworkNamespace, Path: filepath.Join("/var/run/netns", cid), } - spec.Linux.Namespaces = append(spec.Linux.Namespaces, netns) + addNamespace(spec, netns) return func() { c.cleanupNet(cid, dev, resolvPath, hostnamePath, hostsPath) }, nil } diff --git a/runsc/container/container.go b/runsc/container/container.go index 63478ba8c..63f64ce6e 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -985,7 +985,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu // Start the gofer in the given namespace. log.Debugf("Starting gofer: %s %v", binPath, args) if err := specutils.StartInNS(cmd, nss); err != nil { - return nil, nil, fmt.Errorf("Gofer: %v", err) + return nil, nil, fmt.Errorf("gofer: %v", err) } log.Infof("Gofer started, PID: %d", cmd.Process.Pid) c.GoferPid = cmd.Process.Pid diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 1f8e277cc..cc188f45b 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -2362,12 +2362,12 @@ func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte, } // executeSync synchronously executes a new process. -func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { - pid, err := cont.Execute(args) +func (c *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { + pid, err := c.Execute(args) if err != nil { return 0, fmt.Errorf("error executing: %v", err) } - ws, err := cont.WaitPID(pid) + ws, err := c.WaitPID(pid) if err != nil { return 0, fmt.Errorf("error waiting: %v", err) } diff --git a/runsc/main.go b/runsc/main.go index ed244c4ba..4ce5ebee9 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,245 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary runsc is an implementation of the Open Container Initiative Runtime -// that runs applications inside a sandbox. +// Binary runsc implements the OCI runtime interface. package main import ( - "context" - "fmt" - "io" - "io/ioutil" - "os" - "os/signal" - "syscall" - "time" - - "github.com/google/subcommands" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" - "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/runsc/cmd" - "gvisor.dev/gvisor/runsc/config" - "gvisor.dev/gvisor/runsc/flag" - "gvisor.dev/gvisor/runsc/specutils" -) - -var ( - // Although these flags are not part of the OCI spec, they are used by - // Docker, and thus should not be changed. - // TODO(gvisor.dev/issue/193): support systemd cgroups - systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") - showVersion = flag.Bool("version", false, "show version and exit.") - - // These flags are unique to runsc, and are used to configure parts of the - // system that are not covered by the runtime spec. - - // Debugging flags. - logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") - debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") + "gvisor.dev/gvisor/runsc/cli" ) func main() { - // Help and flags commands are generated automatically. - help := cmd.NewHelp(subcommands.DefaultCommander) - help.Register(new(cmd.Syscalls)) - subcommands.Register(help, "") - subcommands.Register(subcommands.FlagsCommand(), "") - - // Installation helpers. - const helperGroup = "helpers" - subcommands.Register(new(cmd.Install), helperGroup) - subcommands.Register(new(cmd.Uninstall), helperGroup) - - // Register user-facing runsc commands. - subcommands.Register(new(cmd.Checkpoint), "") - subcommands.Register(new(cmd.Create), "") - subcommands.Register(new(cmd.Delete), "") - subcommands.Register(new(cmd.Do), "") - subcommands.Register(new(cmd.Events), "") - subcommands.Register(new(cmd.Exec), "") - subcommands.Register(new(cmd.Gofer), "") - subcommands.Register(new(cmd.Kill), "") - subcommands.Register(new(cmd.List), "") - subcommands.Register(new(cmd.Pause), "") - subcommands.Register(new(cmd.PS), "") - subcommands.Register(new(cmd.Restore), "") - subcommands.Register(new(cmd.Resume), "") - subcommands.Register(new(cmd.Run), "") - subcommands.Register(new(cmd.Spec), "") - subcommands.Register(new(cmd.State), "") - subcommands.Register(new(cmd.Start), "") - subcommands.Register(new(cmd.Wait), "") - - // Register internal commands with the internal group name. This causes - // them to be sorted below the user-facing commands with empty group. - // The string below will be printed above the commands. - const internalGroup = "internal use only" - subcommands.Register(new(cmd.Boot), internalGroup) - subcommands.Register(new(cmd.Debug), internalGroup) - subcommands.Register(new(cmd.Gofer), internalGroup) - subcommands.Register(new(cmd.Statefile), internalGroup) - - config.RegisterFlags() - - // All subcommands must be registered before flag parsing. - flag.Parse() - - // Are we showing the version? - if *showVersion { - // The format here is the same as runc. - fmt.Fprintf(os.Stdout, "runsc version %s\n", version) - fmt.Fprintf(os.Stdout, "spec: %s\n", specutils.Version) - os.Exit(0) - } - - // Create a new Config from the flags. - conf, err := config.NewFromFlags() - if err != nil { - cmd.Fatalf(err.Error()) - } - - // TODO(gvisor.dev/issue/193): support systemd cgroups - if *systemdCgroup { - fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") - os.Exit(1) - } - - var errorLogger io.Writer - if *logFD > -1 { - errorLogger = os.NewFile(uintptr(*logFD), "error log file") - - } else if conf.LogFilename != "" { - // We must set O_APPEND and not O_TRUNC because Docker passes - // the same log file for all commands (and also parses these - // log files), so we can't destroy them on each command. - var err error - errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err) - } - } - cmd.ErrorLogger = errorLogger - - if _, err := platform.Lookup(conf.Platform); err != nil { - cmd.Fatalf("%v", err) - } - - // Sets the reference leak check mode. Also set it in config below to - // propagate it to child processes. - refs.SetLeakMode(conf.ReferenceLeak) - - // Set up logging. - if conf.Debug { - log.SetLevel(log.Debug) - } - - // Logging will include the local date and time via the time package. - // - // On first use, time.Local initializes the local time zone, which - // involves opening tzdata files on the host. Since this requires - // opening host files, it must be done before syscall filter - // installation. - // - // Generally there will be a log message before filter installation - // that will force initialization, but force initialization here in - // case that does not occur. - _ = time.Local.String() - - subcommand := flag.CommandLine.Arg(0) - - var e log.Emitter - if *debugLogFD > -1 { - f := os.NewFile(uintptr(*debugLogFD), "debug log file") - - e = newEmitter(conf.DebugLogFormat, f) - - } else if conf.DebugLog != "" { - f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */) - if err != nil { - cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err) - } - e = newEmitter(conf.DebugLogFormat, f) - - } else { - // Stderr is reserved for the application, just discard the logs if no debug - // log is specified. - e = newEmitter("text", ioutil.Discard) - } - - if *panicLogFD > -1 || *debugLogFD > -1 { - fd := *panicLogFD - if fd < 0 { - fd = *debugLogFD - } - // Quick sanity check to make sure no other commands get passed - // a log fd (they should use log dir instead). - if subcommand != "boot" && subcommand != "gofer" { - cmd.Fatalf("flags --debug-log-fd and --panic-log-fd should only be passed to 'boot' and 'gofer' command, but was passed to %q", subcommand) - } - - // If we are the boot process, then we own our stdio FDs and can do what we - // want with them. Since Docker and Containerd both eat boot's stderr, we - // dup our stderr to the provided log FD so that panics will appear in the - // logs, rather than just disappear. - if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { - cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) - } - } else if conf.AlsoLogToStderr { - e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} - } - - log.SetTarget(e) - - log.Infof("***************************") - log.Infof("Args: %s", os.Args) - log.Infof("Version %s", version) - log.Infof("PID: %d", os.Getpid()) - log.Infof("UID: %d, GID: %d", os.Getuid(), os.Getgid()) - log.Infof("Configuration:") - log.Infof("\t\tRootDir: %s", conf.RootDir) - log.Infof("\t\tPlatform: %v", conf.Platform) - log.Infof("\t\tFileAccess: %v, overlay: %t", conf.FileAccess, conf.Overlay) - log.Infof("\t\tNetwork: %v, logging: %t", conf.Network, conf.LogPackets) - log.Infof("\t\tStrace: %t, max size: %d, syscalls: %s", conf.Strace, conf.StraceLogSize, conf.StraceSyscalls) - log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) - log.Infof("***************************") - - if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { - // SIGTERM is sent to all processes if a test exceeds its - // timeout and this case is handled by syscall_test_runner. - log.Warningf("Block the TERM signal. This is only safe in tests!") - signal.Ignore(syscall.SIGTERM) - } - - // Call the subcommand and pass in the configuration. - var ws syscall.WaitStatus - subcmdCode := subcommands.Execute(context.Background(), conf, &ws) - if subcmdCode == subcommands.ExitSuccess { - log.Infof("Exiting with status: %v", ws) - if ws.Signaled() { - // No good way to return it, emulate what the shell does. Maybe raise - // signal to self? - os.Exit(128 + int(ws.Signal())) - } - os.Exit(ws.ExitStatus()) - } - // Return an error that is unlikely to be used by the application. - log.Warningf("Failure to execute command, err: %v", subcmdCode) - os.Exit(128) -} - -func newEmitter(format string, logFile io.Writer) log.Emitter { - switch format { - case "text": - return log.GoogleEmitter{&log.Writer{Next: logFile}} - case "json": - return log.JSONEmitter{&log.Writer{Next: logFile}} - case "json-k8s": - return log.K8sJSONEmitter{&log.Writer{Next: logFile}} - } - cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) - panic("unreachable") + cli.Main(version) } diff --git a/shim/v1/BUILD b/shim/v1/BUILD index 4c9e2c2c6..3614a67d1 100644 --- a/shim/v1/BUILD +++ b/shim/v1/BUILD @@ -4,27 +4,10 @@ package(licenses = ["notice"]) go_binary( name = "gvisor-containerd-shim", - srcs = [ - "api.go", - "config.go", - "main.go", - ], + srcs = ["main.go"], static = True, visibility = [ "//visibility:public", ], - deps = [ - "//pkg/shim/runsc", - "//pkg/shim/v1/shim", - "@com_github_burntsushi_toml//:go_default_library", - "@com_github_containerd_containerd//events:go_default_library", - "@com_github_containerd_containerd//namespaces:go_default_library", - "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", - "@com_github_containerd_containerd//sys:go_default_library", - "@com_github_containerd_containerd//sys/reaper:go_default_library", - "@com_github_containerd_ttrpc//:go_default_library", - "@com_github_containerd_typeurl//:go_default_library", - "@com_github_gogo_protobuf//types:go_default_library", - "@org_golang_x_sys//unix:go_default_library", - ], + deps = ["//shim/v1/cli"], ) diff --git a/shim/v1/cli/BUILD b/shim/v1/cli/BUILD new file mode 100644 index 000000000..0bbdc4add --- /dev/null +++ b/shim/v1/cli/BUILD @@ -0,0 +1,30 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "cli", + srcs = [ + "api.go", + "cli.go", + "config.go", + ], + visibility = [ + "//:__pkg__", + "//shim/v1:__pkg__", + ], + deps = [ + "//pkg/shim/runsc", + "//pkg/shim/v1/shim", + "@com_github_burntsushi_toml//:go_default_library", + "@com_github_containerd_containerd//events:go_default_library", + "@com_github_containerd_containerd//namespaces:go_default_library", + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_containerd//sys:go_default_library", + "@com_github_containerd_containerd//sys/reaper:go_default_library", + "@com_github_containerd_ttrpc//:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_gogo_protobuf//types:go_default_library", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/shim/v1/api.go b/shim/v1/cli/api.go index 2444d23f1..050793094 100644 --- a/shim/v1/api.go +++ b/shim/v1/cli/api.go @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package cli import ( shim "github.com/containerd/containerd/runtime/v1/shim/v1" diff --git a/shim/v1/cli/cli.go b/shim/v1/cli/cli.go new file mode 100644 index 000000000..1a502eabd --- /dev/null +++ b/shim/v1/cli/cli.go @@ -0,0 +1,267 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli defines the command line interface for the V1 shim. +package cli + +import ( + "bytes" + "context" + "flag" + "fmt" + "log" + "net" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "syscall" + + "github.com/containerd/containerd/events" + "github.com/containerd/containerd/namespaces" + "github.com/containerd/containerd/sys" + "github.com/containerd/containerd/sys/reaper" + "github.com/containerd/ttrpc" + "github.com/containerd/typeurl" + "github.com/gogo/protobuf/types" + "golang.org/x/sys/unix" + + "gvisor.dev/gvisor/pkg/shim/runsc" + "gvisor.dev/gvisor/pkg/shim/v1/shim" +) + +var ( + debugFlag bool + namespaceFlag string + socketFlag string + addressFlag string + workdirFlag string + runtimeRootFlag string + containerdBinaryFlag string + shimConfigFlag string +) + +// Containerd defaults to runc, unless another runtime is explicitly specified. +// We keep the same default to make the default behavior consistent. +const defaultRoot = "/run/containerd/runc" + +func init() { + flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") + flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") + flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") + flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") + flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") + flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime") + + // Currently, the `containerd publish` utility is embedded in the + // daemon binary. The daemon invokes `containerd-shim + // -containerd-binary ...` with its own os.Executable() path. + flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)") + flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file") +} + +// Main is the main entrypoint. +func Main() { + flag.Parse() + + // This is a hack. Exec current process to run standard containerd-shim + // if runtime root is not `runsc`. We don't need this for shim v2 api. + if filepath.Base(runtimeRootFlag) != "runsc" { + if err := executeRuncShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } + } + + // Run regular shim if needed. + if err := executeShim(); err != nil { + fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) + os.Exit(1) + } +} + +// executeRuncShim execs current process to a containerd-shim process and +// retains all flags and envs. +func executeRuncShim() error { + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + shimPath := c.RuncShim + if shimPath == "" { + shimPath, err = exec.LookPath("containerd-shim") + if err != nil { + return fmt.Errorf("lookup containerd-shim failed: %w", err) + } + } + + args := append([]string{shimPath}, os.Args[1:]...) + if err := syscall.Exec(shimPath, args, os.Environ()); err != nil { + return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err) + } + return nil +} + +func executeShim() error { + // start handling signals as soon as possible so that things are + // properly reaped or if runtime exits before we hit the handler. + signals, err := setupSignals() + if err != nil { + return err + } + path, err := os.Getwd() + if err != nil { + return err + } + server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) + if err != nil { + return fmt.Errorf("failed creating server: %w", err) + } + c, err := loadConfig(shimConfigFlag) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to load shim config: %w", err) + } + sv, err := shim.NewService( + shim.Config{ + Path: path, + Namespace: namespaceFlag, + WorkDir: workdirFlag, + RuntimeRoot: runtimeRootFlag, + RunscConfig: c.RunscConfig, + }, + &remoteEventsPublisher{address: addressFlag}, + ) + if err != nil { + return err + } + registerShimService(server, sv) + if err := serve(server, socketFlag); err != nil { + return err + } + return handleSignals(signals, server, sv) +} + +// serve serves the ttrpc API over a unix socket at the provided path this +// function does not block. +func serve(server *ttrpc.Server, path string) error { + var ( + l net.Listener + err error + ) + if path == "" { + l, err = net.FileListener(os.NewFile(3, "socket")) + path = "[inherited from parent]" + } else { + if len(path) > 106 { + return fmt.Errorf("%q: unix socket path too long (> 106)", path) + } + l, err = net.Listen("unix", "\x00"+path) + } + if err != nil { + return err + } + go func() { + defer l.Close() + err := server.Serve(context.Background(), l) + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + log.Fatalf("ttrpc server failure: %v", err) + } + }() + return nil +} + +// setupSignals creates a new signal handler for all signals and sets the shim +// as a sub-reaper so that the container processes are reparented. +func setupSignals() (chan os.Signal, error) { + signals := make(chan os.Signal, 32) + signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE) + // make sure runc is setup to use the monitor for waiting on processes. + // TODO(random-liu): Move shim/reaper.go to a separate package. + runsc.Monitor = reaper.Default + // Set the shim as the subreaper for all orphaned processes created by + // the container. + if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { + return nil, err + } + return signals, nil +} + +func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error { + var ( + termOnce sync.Once + done = make(chan struct{}) + ) + + for { + select { + case <-done: + return nil + case s := <-signals: + switch s { + case unix.SIGCHLD: + if _, err := sys.Reap(false); err != nil { + log.Printf("reap error: %v", err) + } + case unix.SIGTERM, unix.SIGINT: + go termOnce.Do(func() { + ctx := context.TODO() + if err := server.Shutdown(ctx); err != nil { + log.Printf("failed to shutdown server: %v", err) + } + // Ensure our child is dead if any. + sv.Kill(ctx, &KillRequest{ + Signal: uint32(syscall.SIGKILL), + All: true, + }) + sv.Delete(context.Background(), &types.Empty{}) + close(done) + }) + case unix.SIGPIPE: + } + } + } +} + +type remoteEventsPublisher struct { + address string +} + +func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { + ns, _ := namespaces.Namespace(ctx) + encoded, err := typeurl.MarshalAny(event) + if err != nil { + return err + } + data, err := encoded.Marshal() + if err != nil { + return err + } + cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns) + cmd.Stdin = bytes.NewReader(data) + c, err := reaper.Default.Start(cmd) + if err != nil { + return err + } + status, err := reaper.Default.Wait(cmd, c) + if err != nil { + return fmt.Errorf("failed to publish event: %w", err) + } + if status != 0 { + return fmt.Errorf("failed to publish event: status %d", status) + } + return nil +} diff --git a/shim/v1/config.go b/shim/v1/cli/config.go index a72cc7754..1be9597ed 100644 --- a/shim/v1/config.go +++ b/shim/v1/cli/config.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package cli import "github.com/BurntSushi/toml" diff --git a/shim/v1/main.go b/shim/v1/main.go index 3159923af..11ff4add1 100644 --- a/shim/v1/main.go +++ b/shim/v1/main.go @@ -1,5 +1,4 @@ -// Copyright 2018 The containerd Authors. -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,253 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Binary gvisor-containerd-shim is the v1 containerd shim. package main import ( - "bytes" - "context" - "flag" - "fmt" - "log" - "net" - "os" - "os/exec" - "os/signal" - "path/filepath" - "strings" - "sync" - "syscall" - - "github.com/containerd/containerd/events" - "github.com/containerd/containerd/namespaces" - "github.com/containerd/containerd/sys" - "github.com/containerd/containerd/sys/reaper" - "github.com/containerd/ttrpc" - "github.com/containerd/typeurl" - "github.com/gogo/protobuf/types" - "golang.org/x/sys/unix" - - "gvisor.dev/gvisor/pkg/shim/runsc" - "gvisor.dev/gvisor/pkg/shim/v1/shim" -) - -var ( - debugFlag bool - namespaceFlag string - socketFlag string - addressFlag string - workdirFlag string - runtimeRootFlag string - containerdBinaryFlag string - shimConfigFlag string + "gvisor.dev/gvisor/shim/v1/cli" ) -// Containerd defaults to runc, unless another runtime is explicitly specified. -// We keep the same default to make the default behavior consistent. -const defaultRoot = "/run/containerd/runc" - -func init() { - flag.BoolVar(&debugFlag, "debug", false, "enable debug output in logs") - flag.StringVar(&namespaceFlag, "namespace", "", "namespace that owns the shim") - flag.StringVar(&socketFlag, "socket", "", "abstract socket path to serve") - flag.StringVar(&addressFlag, "address", "", "grpc address back to main containerd") - flag.StringVar(&workdirFlag, "workdir", "", "path used to storge large temporary data") - flag.StringVar(&runtimeRootFlag, "runtime-root", defaultRoot, "root directory for the runtime") - - // Currently, the `containerd publish` utility is embedded in the - // daemon binary. The daemon invokes `containerd-shim - // -containerd-binary ...` with its own os.Executable() path. - flag.StringVar(&containerdBinaryFlag, "containerd-binary", "containerd", "path to containerd binary (used for `containerd publish`)") - flag.StringVar(&shimConfigFlag, "config", "/etc/containerd/runsc.toml", "path to the shim configuration file") -} - func main() { - flag.Parse() - - // This is a hack. Exec current process to run standard containerd-shim - // if runtime root is not `runsc`. We don't need this for shim v2 api. - if filepath.Base(runtimeRootFlag) != "runsc" { - if err := executeRuncShim(); err != nil { - fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) - os.Exit(1) - } - } - - // Run regular shim if needed. - if err := executeShim(); err != nil { - fmt.Fprintf(os.Stderr, "gvisor-containerd-shim: %s\n", err) - os.Exit(1) - } -} - -// executeRuncShim execs current process to a containerd-shim process and -// retains all flags and envs. -func executeRuncShim() error { - c, err := loadConfig(shimConfigFlag) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to load shim config: %w", err) - } - shimPath := c.RuncShim - if shimPath == "" { - shimPath, err = exec.LookPath("containerd-shim") - if err != nil { - return fmt.Errorf("lookup containerd-shim failed: %w", err) - } - } - - args := append([]string{shimPath}, os.Args[1:]...) - if err := syscall.Exec(shimPath, args, os.Environ()); err != nil { - return fmt.Errorf("exec containerd-shim @ %q failed: %w", shimPath, err) - } - return nil -} - -func executeShim() error { - // start handling signals as soon as possible so that things are - // properly reaped or if runtime exits before we hit the handler. - signals, err := setupSignals() - if err != nil { - return err - } - path, err := os.Getwd() - if err != nil { - return err - } - server, err := ttrpc.NewServer(ttrpc.WithServerHandshaker(ttrpc.UnixSocketRequireSameUser())) - if err != nil { - return fmt.Errorf("failed creating server: %w", err) - } - c, err := loadConfig(shimConfigFlag) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to load shim config: %w", err) - } - sv, err := shim.NewService( - shim.Config{ - Path: path, - Namespace: namespaceFlag, - WorkDir: workdirFlag, - RuntimeRoot: runtimeRootFlag, - RunscConfig: c.RunscConfig, - }, - &remoteEventsPublisher{address: addressFlag}, - ) - if err != nil { - return err - } - registerShimService(server, sv) - if err := serve(server, socketFlag); err != nil { - return err - } - return handleSignals(signals, server, sv) -} - -// serve serves the ttrpc API over a unix socket at the provided path this -// function does not block. -func serve(server *ttrpc.Server, path string) error { - var ( - l net.Listener - err error - ) - if path == "" { - l, err = net.FileListener(os.NewFile(3, "socket")) - path = "[inherited from parent]" - } else { - if len(path) > 106 { - return fmt.Errorf("%q: unix socket path too long (> 106)", path) - } - l, err = net.Listen("unix", "\x00"+path) - } - if err != nil { - return err - } - go func() { - defer l.Close() - err := server.Serve(context.Background(), l) - if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { - log.Fatalf("ttrpc server failure: %v", err) - } - }() - return nil -} - -// setupSignals creates a new signal handler for all signals and sets the shim -// as a sub-reaper so that the container processes are reparented. -func setupSignals() (chan os.Signal, error) { - signals := make(chan os.Signal, 32) - signal.Notify(signals, unix.SIGTERM, unix.SIGINT, unix.SIGCHLD, unix.SIGPIPE) - // make sure runc is setup to use the monitor for waiting on processes. - // TODO(random-liu): Move shim/reaper.go to a separate package. - runsc.Monitor = reaper.Default - // Set the shim as the subreaper for all orphaned processes created by - // the container. - if err := unix.Prctl(unix.PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0); err != nil { - return nil, err - } - return signals, nil -} - -func handleSignals(signals chan os.Signal, server *ttrpc.Server, sv *shim.Service) error { - var ( - termOnce sync.Once - done = make(chan struct{}) - ) - - for { - select { - case <-done: - return nil - case s := <-signals: - switch s { - case unix.SIGCHLD: - if _, err := sys.Reap(false); err != nil { - log.Printf("reap error: %v", err) - } - case unix.SIGTERM, unix.SIGINT: - go termOnce.Do(func() { - ctx := context.TODO() - if err := server.Shutdown(ctx); err != nil { - log.Printf("failed to shutdown server: %v", err) - } - // Ensure our child is dead if any. - sv.Kill(ctx, &KillRequest{ - Signal: uint32(syscall.SIGKILL), - All: true, - }) - sv.Delete(context.Background(), &types.Empty{}) - close(done) - }) - case unix.SIGPIPE: - } - } - } -} - -type remoteEventsPublisher struct { - address string -} - -func (l *remoteEventsPublisher) Publish(ctx context.Context, topic string, event events.Event) error { - ns, _ := namespaces.Namespace(ctx) - encoded, err := typeurl.MarshalAny(event) - if err != nil { - return err - } - data, err := encoded.Marshal() - if err != nil { - return err - } - cmd := exec.CommandContext(ctx, containerdBinaryFlag, "--address", l.address, "publish", "--topic", topic, "--namespace", ns) - cmd.Stdin = bytes.NewReader(data) - c, err := reaper.Default.Start(cmd) - if err != nil { - return err - } - status, err := reaper.Default.Wait(cmd, c) - if err != nil { - return fmt.Errorf("failed to publish event: %w", err) - } - if status != 0 { - return fmt.Errorf("failed to publish event: status %d", status) - } - return nil + cli.Main() } diff --git a/shim/v2/BUILD b/shim/v2/BUILD index 8de9ac0ba..b4a107d27 100644 --- a/shim/v2/BUILD +++ b/shim/v2/BUILD @@ -4,15 +4,10 @@ package(licenses = ["notice"]) go_binary( name = "containerd-shim-runsc-v1", - srcs = [ - "main.go", - ], + srcs = ["main.go"], static = True, visibility = [ "//visibility:public", ], - deps = [ - "//pkg/shim/v2", - "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", - ], + deps = ["//shim/v2/cli"], ) diff --git a/shim/v2/cli/BUILD b/shim/v2/cli/BUILD new file mode 100644 index 000000000..6681e0772 --- /dev/null +++ b/shim/v2/cli/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "cli", + srcs = ["cli.go"], + visibility = [ + "//:__pkg__", + "//shim/v2:__pkg__", + ], + deps = [ + "//pkg/shim/v2", + "@com_github_containerd_containerd//runtime/v2/shim:go_default_library", + ], +) diff --git a/shim/v2/cli/cli.go b/shim/v2/cli/cli.go new file mode 100644 index 000000000..3d6644feb --- /dev/null +++ b/shim/v2/cli/cli.go @@ -0,0 +1,28 @@ +// Copyright 2018 The containerd Authors. +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package cli defines the command line interface for the V2 shim. +package cli + +import ( + "github.com/containerd/containerd/runtime/v2/shim" + + "gvisor.dev/gvisor/pkg/shim/v2" +) + +// Main is the main entrypoint. +func Main() { + shim.Run("io.containerd.runsc.v1", v2.New) +} diff --git a/shim/v2/main.go b/shim/v2/main.go index 753871eea..3680cdf9c 100644 --- a/shim/v2/main.go +++ b/shim/v2/main.go @@ -1,5 +1,4 @@ -// Copyright 2018 The containerd Authors. -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,14 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Binary containerd-shim-runsc-v1 is the v2 containerd shim (implementing the formal v1 API). package main import ( - "github.com/containerd/containerd/runtime/v2/shim" - - "gvisor.dev/gvisor/pkg/shim/v2" + "gvisor.dev/gvisor/shim/v2/cli" ) func main() { - shim.Run("io.containerd.runsc.v1", v2.New) + cli.Main() } diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md index ffa96ba98..fe0976ba5 100644 --- a/test/packetimpact/README.md +++ b/test/packetimpact/README.md @@ -694,6 +694,13 @@ func TestMyTcpTest(t *testing.T) { } ``` +### Adding a new packetimpact test + +* Create a go test in the [tests directory](tests/) +* Add a `packetimpact_testbench` rule in [BUILD](tests/BUILD) +* Add the test into the `ALL_TESTS` list in [defs.bzl](runner/defs.bzl), + otherwise you will see an error message complaining about a missing test. + ## Other notes * The time between receiving a SYN-ACK and replying with an ACK in `Handshake` diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index f56d3c42e..1546d0d51 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -110,29 +110,15 @@ def packetimpact_netstack_test( **kwargs ) -def packetimpact_go_test(name, size = "small", pure = True, expect_native_failure = False, expect_netstack_failure = False, **kwargs): +def packetimpact_go_test(name, expect_native_failure = False, expect_netstack_failure = False): """Add packetimpact tests written in go. Args: name: name of the test - size: size of the test - pure: make a static go binary expect_native_failure: the test must fail natively expect_netstack_failure: the test must fail for Netstack - **kwargs: all the other args, forwarded to go_test """ testbench_binary = name + "_test" - go_test( - name = testbench_binary, - size = size, - pure = pure, - nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. - tags = [ - "local", - "manual", - ], - **kwargs - ) packetimpact_native_test( name = name, expect_failure = expect_native_failure, @@ -143,3 +129,156 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_native_failur expect_failure = expect_netstack_failure, testbench_binary = testbench_binary, ) + +def packetimpact_testbench(name, size = "small", pure = True, **kwargs): + """Build packetimpact testbench written in go. + + Args: + name: name of the test + size: size of the test + pure: make a static go binary + **kwargs: all the other args, forwarded to go_test + """ + go_test( + name = name + "_test", + size = size, + pure = pure, + nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + tags = [ + "local", + "manual", + ], + **kwargs + ) + +PacketimpactTestInfo = provider( + doc = "Provide information for packetimpact tests", + fields = ["name", "expect_netstack_failure"], +) + +ALL_TESTS = [ + PacketimpactTestInfo( + name = "fin_wait2_timeout", + ), + PacketimpactTestInfo( + name = "ipv4_id_uniqueness", + ), + PacketimpactTestInfo( + name = "udp_discard_mcast_source_addr", + ), + PacketimpactTestInfo( + name = "udp_recv_mcast_bcast", + ), + PacketimpactTestInfo( + name = "udp_any_addr_recv_unicast", + ), + PacketimpactTestInfo( + name = "udp_icmp_error_propagation", + ), + PacketimpactTestInfo( + name = "tcp_reordering", + # TODO(b/139368047): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "tcp_window_shrink", + ), + PacketimpactTestInfo( + name = "tcp_zero_window_probe", + ), + PacketimpactTestInfo( + name = "tcp_zero_window_probe_retransmit", + ), + PacketimpactTestInfo( + name = "tcp_zero_window_probe_usertimeout", + ), + PacketimpactTestInfo( + name = "tcp_retransmits", + ), + PacketimpactTestInfo( + name = "tcp_outside_the_window", + ), + PacketimpactTestInfo( + name = "tcp_noaccept_close_rst", + ), + PacketimpactTestInfo( + name = "tcp_send_window_sizes_piggyback", + ), + PacketimpactTestInfo( + name = "tcp_unacc_seq_ack", + ), + PacketimpactTestInfo( + name = "tcp_paws_mechanism", + # TODO(b/156682000): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "tcp_user_timeout", + ), + PacketimpactTestInfo( + name = "tcp_queue_receive_in_syn_sent", + ), + PacketimpactTestInfo( + name = "tcp_synsent_reset", + ), + PacketimpactTestInfo( + name = "tcp_synrcvd_reset", + ), + PacketimpactTestInfo( + name = "tcp_network_unreachable", + ), + PacketimpactTestInfo( + name = "tcp_cork_mss", + ), + PacketimpactTestInfo( + name = "tcp_handshake_window_size", + ), + PacketimpactTestInfo( + name = "tcp_timewait_reset", + # TODO(b/168523247): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "tcp_queue_send_in_syn_sent", + ), + PacketimpactTestInfo( + name = "icmpv6_param_problem", + # TODO(b/153485026): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "ipv6_unknown_options_action", + # TODO(b/159928940): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( + name = "ipv6_fragment_reassembly", + ), + PacketimpactTestInfo( + name = "udp_send_recv_dgram", + ), + PacketimpactTestInfo( + name = "tcp_linger", + ), + PacketimpactTestInfo( + name = "tcp_rcv_buf_space", + ), +] + +def validate_all_tests(): + """ + Make sure that ALL_TESTS list is in sync with the rules in BUILD. + + This function is order-dependent, it is intended to be used after + all packetimpact_testbench rules and before using ALL_TESTS list + at the end of BUILD. + """ + all_tests_dict = {} # there is no set, using dict to approximate. + for test in ALL_TESTS: + rule_name = test.name + "_test" + all_tests_dict[rule_name] = True + if not native.existing_rule(rule_name): + fail("%s does not have a packetimpact_testbench rule in BUILD" % test.name) + for name in native.existing_rules(): + if name.endswith("_test") and name not in all_tests_dict: + fail("%s is not declared in ALL_TESTS list in defs.bzl" % name[:-5]) diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 11db49e39..8c2de5a9f 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -1,11 +1,11 @@ -load("//test/packetimpact/runner:defs.bzl", "packetimpact_go_test") +load("//test/packetimpact/runner:defs.bzl", "ALL_TESTS", "packetimpact_go_test", "packetimpact_testbench", "validate_all_tests") package( default_visibility = ["//test/packetimpact:__subpackages__"], licenses = ["notice"], ) -packetimpact_go_test( +packetimpact_testbench( name = "fin_wait2_timeout", srcs = ["fin_wait2_timeout_test.go"], deps = [ @@ -15,7 +15,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "ipv4_id_uniqueness", srcs = ["ipv4_id_uniqueness_test.go"], deps = [ @@ -26,7 +26,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "udp_discard_mcast_source_addr", srcs = ["udp_discard_mcast_source_addr_test.go"], deps = [ @@ -37,7 +37,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "udp_recv_mcast_bcast", srcs = ["udp_recv_mcast_bcast_test.go"], deps = [ @@ -49,7 +49,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "udp_any_addr_recv_unicast", srcs = ["udp_any_addr_recv_unicast_test.go"], deps = [ @@ -60,7 +60,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "udp_icmp_error_propagation", srcs = ["udp_icmp_error_propagation_test.go"], deps = [ @@ -71,11 +71,9 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_reordering", srcs = ["tcp_reordering_test.go"], - # TODO(b/139368047): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", @@ -84,7 +82,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_window_shrink", srcs = ["tcp_window_shrink_test.go"], deps = [ @@ -94,7 +92,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_zero_window_probe", srcs = ["tcp_zero_window_probe_test.go"], deps = [ @@ -104,7 +102,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_zero_window_probe_retransmit", srcs = ["tcp_zero_window_probe_retransmit_test.go"], deps = [ @@ -114,7 +112,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_zero_window_probe_usertimeout", srcs = ["tcp_zero_window_probe_usertimeout_test.go"], deps = [ @@ -124,7 +122,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_retransmits", srcs = ["tcp_retransmits_test.go"], deps = [ @@ -134,7 +132,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_outside_the_window", srcs = ["tcp_outside_the_window_test.go"], deps = [ @@ -145,7 +143,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_noaccept_close_rst", srcs = ["tcp_noaccept_close_rst_test.go"], deps = [ @@ -155,7 +153,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_send_window_sizes_piggyback", srcs = ["tcp_send_window_sizes_piggyback_test.go"], deps = [ @@ -165,7 +163,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_unacc_seq_ack", srcs = ["tcp_unacc_seq_ack_test.go"], deps = [ @@ -176,11 +174,9 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_paws_mechanism", srcs = ["tcp_paws_mechanism_test.go"], - # TODO(b/156682000): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", @@ -189,7 +185,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_user_timeout", srcs = ["tcp_user_timeout_test.go"], deps = [ @@ -199,7 +195,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_queue_receive_in_syn_sent", srcs = ["tcp_queue_receive_in_syn_sent_test.go"], deps = [ @@ -209,7 +205,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_synsent_reset", srcs = ["tcp_synsent_reset_test.go"], deps = [ @@ -219,7 +215,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_synrcvd_reset", srcs = ["tcp_synrcvd_reset_test.go"], deps = [ @@ -229,7 +225,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_network_unreachable", srcs = ["tcp_network_unreachable_test.go"], deps = [ @@ -239,7 +235,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_cork_mss", srcs = ["tcp_cork_mss_test.go"], deps = [ @@ -249,7 +245,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_handshake_window_size", srcs = ["tcp_handshake_window_size_test.go"], deps = [ @@ -259,11 +255,9 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_timewait_reset", srcs = ["tcp_timewait_reset_test.go"], - # TODO(b/168523247): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip/header", "//test/packetimpact/testbench", @@ -271,7 +265,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_queue_send_in_syn_sent", srcs = ["tcp_queue_send_in_syn_sent_test.go"], deps = [ @@ -281,11 +275,9 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "icmpv6_param_problem", srcs = ["icmpv6_param_problem_test.go"], - # TODO(b/153485026): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", @@ -294,11 +286,9 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "ipv6_unknown_options_action", srcs = ["ipv6_unknown_options_action_test.go"], - # TODO(b/159928940): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", @@ -307,7 +297,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "ipv6_fragment_reassembly", srcs = ["ipv6_fragment_reassembly_test.go"], deps = [ @@ -319,7 +309,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "udp_send_recv_dgram", srcs = ["udp_send_recv_dgram_test.go"], deps = [ @@ -329,7 +319,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_linger", srcs = ["tcp_linger_test.go"], deps = [ @@ -339,7 +329,7 @@ packetimpact_go_test( ], ) -packetimpact_go_test( +packetimpact_testbench( name = "tcp_rcv_buf_space", srcs = ["tcp_rcv_buf_space_test.go"], deps = [ @@ -348,3 +338,10 @@ packetimpact_go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +validate_all_tests() + +[packetimpact_go_test( + name = t.name, + expect_netstack_failure = hasattr(t, "expect_netstack_failure"), +) for t in ALL_TESTS] diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 36b7f1b97..572f39a5d 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -3624,6 +3624,7 @@ cc_binary( "//test/util:signal_util", "//test/util:test_util", "//test/util:thread_util", + "//test/util:timer_util", ], ) diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index 549141cbb..b286e84fe 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -216,14 +216,29 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - // Register a signal handler for SIGALRM and set an alarm that will go off - // while blocking in the subsequent flock() call. This will interrupt flock() - // and cause it to return EINTR. + // Make sure that a blocking flock() call will return EINTR when interrupted + // by a signal. Create a timer that will go off while blocking on flock(), and + // register the corresponding signal handler. + auto timer = ASSERT_NO_ERRNO_AND_VALUE( + TimerCreate(CLOCK_MONOTONIC, sigevent_t{ + .sigev_signo = SIGALRM, + .sigev_notify = SIGEV_SIGNAL, + })); + struct sigaction act = {}; act.sa_handler = trivial_handler; ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); - ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + + // Now that the signal handler is registered, set the timer. Set an interval + // so that it's ok if the timer goes off before we call flock. + ASSERT_NO_ERRNO( + timer.Set(0, itimerspec{ + .it_interval = absl::ToTimespec(absl::Milliseconds(10)), + .it_value = absl::ToTimespec(absl::Milliseconds(10)), + })); + ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallFailsWithErrno(EINTR)); + timer.reset(); // Unlock ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); @@ -258,14 +273,29 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); - // Register a signal handler for SIGALRM and set an alarm that will go off - // while blocking in the subsequent flock() call. This will interrupt flock() - // and cause it to return EINTR. + // Make sure that a blocking flock() call will return EINTR when interrupted + // by a signal. Create a timer that will go off while blocking on flock(), and + // register the corresponding signal handler. + auto timer = ASSERT_NO_ERRNO_AND_VALUE( + TimerCreate(CLOCK_MONOTONIC, sigevent_t{ + .sigev_signo = SIGALRM, + .sigev_notify = SIGEV_SIGNAL, + })); + struct sigaction act = {}; act.sa_handler = trivial_handler; ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); - ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + + // Now that the signal handler is registered, set the timer. Set an interval + // so that it's ok if the timer goes off before we call flock. + ASSERT_NO_ERRNO( + timer.Set(0, itimerspec{ + .it_interval = absl::ToTimespec(absl::Milliseconds(10)), + .it_value = absl::ToTimespec(absl::Milliseconds(10)), + })); + ASSERT_THAT(flock(fd.get(), LOCK_EX), SyscallFailsWithErrno(EINTR)); + timer.reset(); // Unlock ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index ae65d366b..b96907b30 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -93,15 +93,15 @@ TEST(MknodTest, MknodOnExistingPathFails) { } TEST(MknodTest, UnimplementedTypesReturnError) { - const std::string path = NewTempAbsPath(); + // TODO(gvisor.dev/issue/1624): These file types are supported by some + // filesystems in VFS2, so this test should be deleted along with VFS1. + SKIP_IF(!IsRunningWithVFS1()); - if (IsRunningWithVFS1()) { - ASSERT_THAT(mknod(path.c_str(), S_IFSOCK, 0), - SyscallFailsWithErrno(EOPNOTSUPP)); - } - // These will fail on linux as well since we don't have CAP_MKNOD. - ASSERT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM)); - ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); + const std::string path = NewTempAbsPath(); + EXPECT_THAT(mknod(path.c_str(), S_IFSOCK, 0), + SyscallFailsWithErrno(EOPNOTSUPP)); + EXPECT_THAT(mknod(path.c_str(), S_IFCHR, 0), SyscallFailsWithErrno(EPERM)); + EXPECT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); } TEST(MknodTest, Socket) { diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index c097c9187..06d9dbf65 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -569,30 +569,38 @@ TEST_P(PipeTest, Streaming) { // Size() requires 2 syscalls, call it once and remember the value. const int pipe_size = Size(); + const size_t streamed_bytes = 4 * pipe_size; absl::Notification notify; - ScopedThread t([this, ¬ify, pipe_size]() { + ScopedThread t([&, this]() { + std::vector<char> buf(1024); // Don't start until it's full. notify.WaitForNotification(); - for (int i = 0; i < pipe_size; i++) { - int rbuf; - ASSERT_THAT(read(rfd_.get(), &rbuf, sizeof(rbuf)), - SyscallSucceedsWithValue(sizeof(rbuf))); - EXPECT_EQ(rbuf, i); + ssize_t total = 0; + while (total < streamed_bytes) { + ASSERT_THAT(read(rfd_.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + total += buf.size(); } }); // Write 4 bytes * pipe_size. It will fill up the pipe once, notify the reader // to start. Then we write pipe size worth 3 more times to ensure the reader // can follow along. + // + // The size of each write (which is determined by buf.size()) must be smaller + // than the size of the pipe (which, in the "smallbuffer" configuration, is 1 + // page) for the check for notify.Notify() below to be correct. + std::vector<char> buf(1024); + RandomizeBuffer(buf.data(), buf.size()); ssize_t total = 0; - for (int i = 0; i < pipe_size; i++) { - ssize_t written = write(wfd_.get(), &i, sizeof(i)); - ASSERT_THAT(written, SyscallSucceedsWithValue(sizeof(i))); - total += written; + while (total < streamed_bytes) { + ASSERT_THAT(write(wfd_.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + total += buf.size(); // Is the next write about to fill up the buffer? Wake up the reader once. - if (total < pipe_size && (total + written) >= pipe_size) { + if (total < pipe_size && (total + buf.size()) >= pipe_size) { notify.Notify(); } } diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index f4b69c46c..831d96262 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -14,6 +14,7 @@ #include "test/syscalls/linux/socket_ip_tcp_generic.h" +#include <fcntl.h> #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -979,6 +980,56 @@ TEST_P(TCPSocketPairTest, SetTCPUserTimeoutAboveZero) { EXPECT_EQ(get, kAbove); } +#ifdef __linux__ +TEST_P(TCPSocketPairTest, SpliceFromPipe) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize / 2); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + + EXPECT_THAT( + splice(rfd.get(), nullptr, sockets->first_fd(), nullptr, kPageSize, 0), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> rbuf(buf.size()); + ASSERT_THAT(read(sockets->second_fd(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(buf.size())); + EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); +} + +TEST_P(TCPSocketPairTest, SpliceToPipe) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize / 2); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(sockets->first_fd(), buf.data(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + shutdown(sockets->first_fd(), SHUT_WR); + EXPECT_THAT( + splice(sockets->second_fd(), nullptr, wfd.get(), nullptr, kPageSize, 0), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> rbuf(buf.size()); + ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(buf.size())); + EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); +} +#endif // __linux__ + TEST_P(TCPSocketPairTest, SetTCPWindowClampBelowMinRcvBufConnectedSocket) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Discover minimum receive buf by setting a really low value diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index b3fcf8e7c..241ddad74 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <arpa/inet.h> +#include <fcntl.h> #include <ifaddrs.h> #include <linux/if.h> #include <linux/netlink.h> @@ -335,6 +336,49 @@ TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { EXPECT_EQ((msg.msg_flags & MSG_TRUNC), MSG_TRUNC); } +TEST(NetlinkRouteTest, SpliceFromPipe) { + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_GETLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link.index; + + ASSERT_THAT(write(wfd.get(), &req, sizeof(req)), + SyscallSucceedsWithValue(sizeof(req))); + + EXPECT_THAT(splice(rfd.get(), nullptr, fd.get(), nullptr, sizeof(req) + 1, 0), + SyscallSucceedsWithValue(sizeof(req))); + close(wfd.release()); + EXPECT_THAT(splice(rfd.get(), nullptr, fd.get(), nullptr, sizeof(req) + 1, 0), + SyscallSucceedsWithValue(0)); + + bool found = false; + ASSERT_NO_ERRNO(NetlinkResponse( + fd, + [&](const struct nlmsghdr* hdr) { + CheckLinkMsg(hdr, loopback_link); + found = true; + }, + false)); + EXPECT_TRUE(found) << "Netlink response does not contain any links."; +} + TEST(NetlinkRouteTest, MsgTruncMsgHdrMsgTrunc) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index 952eecfe8..bdebea321 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -67,10 +67,21 @@ PosixError NetlinkRequestResponse( RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0)); + return NetlinkResponse(fd, fn, expect_nlmsgerr); +} + +PosixError NetlinkResponse( + const FileDescriptor& fd, + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr) { constexpr size_t kBufferSize = 4096; std::vector<char> buf(kBufferSize); + struct iovec iov = {}; iov.iov_base = buf.data(); iov.iov_len = buf.size(); + struct msghdr msg = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; // If NLM_F_MULTI is set, response is a series of messages that ends with a // NLMSG_DONE message. diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index e13ead406..f97276d44 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -41,6 +41,14 @@ PosixError NetlinkRequestResponse( const std::function<void(const struct nlmsghdr* hdr)>& fn, bool expect_nlmsgerr); +// Call fn on all response netlink messages. +// +// To be used on requests with NLM_F_MULTI reponses. +PosixError NetlinkResponse( + const FileDescriptor& fd, + const std::function<void(const struct nlmsghdr* hdr)>& fn, + bool expect_nlmsgerr); + // Send the passed request and call fn on all response netlink messages. // // To be used on requests without NLM_F_MULTI reponses. diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc index 4b3c44527..cac94d9e1 100644 --- a/test/syscalls/linux/timers.cc +++ b/test/syscalls/linux/timers.cc @@ -33,6 +33,7 @@ #include "test/util/signal_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/timer_util.h" ABSL_FLAG(bool, timers_test_sleep, false, "If true, sleep forever instead of running tests."); @@ -215,99 +216,6 @@ TEST(TimerTest, ProcessKilledOnCPUHardLimit) { EXPECT_GE(cpu, kHardLimit); } -// RAII type for a kernel "POSIX" interval timer. (The kernel provides system -// calls such as timer_create that behave very similarly, but not identically, -// to those described by timer_create(2); in particular, the kernel does not -// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on -// these kernel interval timers.) -// -// Compare implementation to FileDescriptor. -class IntervalTimer { - public: - IntervalTimer() = default; - - explicit IntervalTimer(int id) { set_id(id); } - - IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {} - - IntervalTimer& operator=(IntervalTimer&& orig) { - if (this == &orig) return *this; - reset(orig.release()); - return *this; - } - - IntervalTimer(const IntervalTimer& other) = delete; - IntervalTimer& operator=(const IntervalTimer& other) = delete; - - ~IntervalTimer() { reset(); } - - int get() const { return id_; } - - int release() { - int const id = id_; - id_ = -1; - return id; - } - - void reset() { reset(-1); } - - void reset(int id) { - if (id_ >= 0) { - TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0); - MaybeSave(); - } - set_id(id); - } - - PosixErrorOr<struct itimerspec> Set( - int flags, const struct itimerspec& new_value) const { - struct itimerspec old_value = {}; - if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) { - return PosixError(errno, "timer_settime"); - } - MaybeSave(); - return old_value; - } - - PosixErrorOr<struct itimerspec> Get() const { - struct itimerspec curr_value = {}; - if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) { - return PosixError(errno, "timer_gettime"); - } - MaybeSave(); - return curr_value; - } - - PosixErrorOr<int> Overruns() const { - int rv = syscall(SYS_timer_getoverrun, id_); - if (rv < 0) { - return PosixError(errno, "timer_getoverrun"); - } - MaybeSave(); - return rv; - } - - private: - void set_id(int id) { id_ = std::max(id, -1); } - - // Kernel timer_t is int; glibc timer_t is void*. - int id_ = -1; -}; - -PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid, - const struct sigevent& sev) { - int timerid; - int ret = syscall(SYS_timer_create, clockid, &sev, &timerid); - if (ret < 0) { - return PosixError(errno, "timer_create"); - } - if (ret > 0) { - return PosixError(EINVAL, "timer_create should never return positive"); - } - MaybeSave(); - return IntervalTimer(timerid); -} - // See timerfd.cc:TimerSlack() for rationale. constexpr absl::Duration kTimerSlack = absl::Milliseconds(500); diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc index d0aea8e6a..fbac94912 100644 --- a/test/util/save_util_linux.cc +++ b/test/util/save_util_linux.cc @@ -46,4 +46,4 @@ void MaybeSave() { } // namespace testing } // namespace gvisor -#endif +#endif // __linux__ diff --git a/test/util/timer_util.cc b/test/util/timer_util.cc index 43a26b0d3..75cfc4f40 100644 --- a/test/util/timer_util.cc +++ b/test/util/timer_util.cc @@ -23,5 +23,23 @@ absl::Time Now(clockid_t id) { return absl::TimeFromTimespec(now); } +#ifdef __linux__ + +PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid, + const struct sigevent& sev) { + int timerid; + int ret = syscall(SYS_timer_create, clockid, &sev, &timerid); + if (ret < 0) { + return PosixError(errno, "timer_create"); + } + if (ret > 0) { + return PosixError(EINVAL, "timer_create should never return positive"); + } + MaybeSave(); + return IntervalTimer(timerid); +} + +#endif // __linux__ + } // namespace testing } // namespace gvisor diff --git a/test/util/timer_util.h b/test/util/timer_util.h index 31aea4fc6..926e6632f 100644 --- a/test/util/timer_util.h +++ b/test/util/timer_util.h @@ -16,6 +16,9 @@ #define GVISOR_TEST_UTIL_TIMER_UTIL_H_ #include <errno.h> +#ifdef __linux__ +#include <sys/syscall.h> +#endif #include <sys/time.h> #include <functional> @@ -30,6 +33,9 @@ namespace gvisor { namespace testing { +// Returns the current time. +absl::Time Now(clockid_t id); + // MonotonicTimer is a simple timer that uses a monotonic clock. class MonotonicTimer { public: @@ -65,8 +71,92 @@ inline PosixErrorOr<Cleanup> ScopedItimer(int which, })); } -// Returns the current time. -absl::Time Now(clockid_t id); +#ifdef __linux__ + +// RAII type for a kernel "POSIX" interval timer. (The kernel provides system +// calls such as timer_create that behave very similarly, but not identically, +// to those described by timer_create(2); in particular, the kernel does not +// implement SIGEV_THREAD. glibc builds POSIX-compliant interval timers based on +// these kernel interval timers.) +// +// Compare implementation to FileDescriptor. +class IntervalTimer { + public: + IntervalTimer() = default; + + explicit IntervalTimer(int id) { set_id(id); } + + IntervalTimer(IntervalTimer&& orig) : id_(orig.release()) {} + + IntervalTimer& operator=(IntervalTimer&& orig) { + if (this == &orig) return *this; + reset(orig.release()); + return *this; + } + + IntervalTimer(const IntervalTimer& other) = delete; + IntervalTimer& operator=(const IntervalTimer& other) = delete; + + ~IntervalTimer() { reset(); } + + int get() const { return id_; } + + int release() { + int const id = id_; + id_ = -1; + return id; + } + + void reset() { reset(-1); } + + void reset(int id) { + if (id_ >= 0) { + TEST_PCHECK(syscall(SYS_timer_delete, id_) == 0); + MaybeSave(); + } + set_id(id); + } + + PosixErrorOr<struct itimerspec> Set( + int flags, const struct itimerspec& new_value) const { + struct itimerspec old_value = {}; + if (syscall(SYS_timer_settime, id_, flags, &new_value, &old_value) < 0) { + return PosixError(errno, "timer_settime"); + } + MaybeSave(); + return old_value; + } + + PosixErrorOr<struct itimerspec> Get() const { + struct itimerspec curr_value = {}; + if (syscall(SYS_timer_gettime, id_, &curr_value) < 0) { + return PosixError(errno, "timer_gettime"); + } + MaybeSave(); + return curr_value; + } + + PosixErrorOr<int> Overruns() const { + int rv = syscall(SYS_timer_getoverrun, id_); + if (rv < 0) { + return PosixError(errno, "timer_getoverrun"); + } + MaybeSave(); + return rv; + } + + private: + void set_id(int id) { id_ = std::max(id, -1); } + + // Kernel timer_t is int; glibc timer_t is void*. + int id_ = -1; +}; + +// A wrapper around timer_create(2). +PosixErrorOr<IntervalTimer> TimerCreate(clockid_t clockid, + const struct sigevent& sev); + +#endif // __linux__ } // namespace testing } // namespace gvisor diff --git a/tools/bazel.mk b/tools/bazel.mk index 25575c02c..88431ce66 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -14,12 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Make hacks. +EMPTY := +SPACE := $(EMPTY) $(EMPTY) + # See base Makefile. SHELL=/bin/bash -o pipefail BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \ git rev-parse --abbrev-ref HEAD 2>/dev/null) | \ xargs -n 1 basename 2>/dev/null) -BUILD_ROOT := $(CURDIR)/bazel-bin/ +BUILD_ROOTS := bazel-bin/ bazel-out/ # Bazel container configuration (see below). USER ?= gvisor @@ -152,10 +156,12 @@ build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefai build_paths = $(build_cmd) 2>&1 \ | tee /proc/self/fd/2 \ - | grep " bazel-bin/" \ + | grep -A1 -E '^Target' \ + | grep -E '^ ($(subst $(SPACE),|,$(BUILD_ROOTS)))' \ | sed "s/ /\n/g" \ | strings -n 10 \ | awk '{$$1=$$1};1' \ + | xargs -n 1 -I {} readlink -f "{}" \ | xargs -n 1 -I {} sh -c "$(1)" build: bazel-server diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD index 8d4356119..d043caf06 100644 --- a/tools/bazeldefs/BUILD +++ b/tools/bazeldefs/BUILD @@ -26,43 +26,6 @@ rbe_platform( remote_execution_properties = """ properties: { name: "container-image" - value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:93f7e127196b9b653d39830c50f8b05d49ef6fd8739a9b5b8ab16e1df5399e50" - } - properties: { - name: "dockerAddCapabilities" - value: "SYS_ADMIN" - } - properties: { - name: "dockerPrivileged" - value: "true" - } - """, -) - -rbe_toolchain( - name = "cc-toolchain-clang-x86_64-default", - exec_compatible_with = [], - tags = [ - "manual", - ], - target_compatible_with = [], - toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8", - toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", -) - -# Updated versions of the above, compatible with bazel3. -rbe_platform( - name = "rbe_ubuntu1604_bazel3", - constraint_values = [ - "@bazel_tools//platforms:x86_64", - "@bazel_tools//platforms:linux", - "@bazel_tools//tools/cpp:clang", - "@bazel_toolchains_bazel3//constraints:xenial", - "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan", - ], - remote_execution_properties = """ - properties: { - name: "container-image" value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272" } properties: { @@ -77,13 +40,13 @@ rbe_platform( ) rbe_toolchain( - name = "cc-toolchain-clang-x86_64-default_bazel3", + name = "cc-toolchain-clang-x86_64-default", exec_compatible_with = [], tags = [ "manual", ], target_compatible_with = [], - toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8", + toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8", toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", ) diff --git a/tools/bazeldefs/cc.bzl b/tools/bazeldefs/cc.bzl new file mode 100644 index 000000000..7f41a0142 --- /dev/null +++ b/tools/bazeldefs/cc.bzl @@ -0,0 +1,43 @@ +"""C++ rules.""" + +load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") +load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library") + +cc_library = _cc_library +cc_flags_supplier = _cc_flags_supplier +cc_proto_library = _cc_proto_library +cc_test = _cc_test +cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" +gtest = "@com_google_googletest//:gtest" +gbenchmark = "@com_google_benchmark//:benchmark" +grpcpp = "@com_github_grpc_grpc//:grpc++" +vdso_linker_option = "-fuse-ld=gold " + +def cc_grpc_library(name, **kwargs): + _cc_grpc_library(name = name, grpc_only = True, **kwargs) + +def cc_binary(name, static = False, **kwargs): + """Run cc_binary. + + Args: + name: name of the target. + static: make a static binary if True + **kwargs: the rest of the args. + """ + if static: + # How to statically link a c++ program that uses threads, like for gRPC: + # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html + if "linkopts" not in kwargs: + kwargs["linkopts"] = [] + kwargs["linkopts"] += [ + "-static", + "-lstdc++", + "-Wl,--whole-archive", + "-lpthread", + "-Wl,--no-whole-archive", + ] + _cc_binary( + name = name, + **kwargs + ) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index cf5b1dc0d..ba186aace 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -1,35 +1,13 @@ -"""Bazel implementations of standard rules.""" +"""Meta and miscellaneous rules.""" -load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle") load("@bazel_skylib//rules:build_test.bzl", _build_test = "build_test") load("@bazel_skylib//:bzl_library.bzl", _bzl_library = "bzl_library") -load("@bazel_tools//tools/cpp:cc_flags_supplier.bzl", _cc_flags_supplier = "cc_flags_supplier") -load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test") -load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library") -load("@rules_cc//cc:defs.bzl", _cc_binary = "cc_binary", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test") -load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") -load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", _cc_grpc_library = "cc_grpc_library") build_test = _build_test bzl_library = _bzl_library -cc_library = _cc_library -cc_flags_supplier = _cc_flags_supplier -cc_proto_library = _cc_proto_library -cc_test = _cc_test -cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain" -gazelle = _gazelle -go_embed_data = _go_embed_data -go_path = _go_path -gtest = "@com_google_googletest//:gtest" -grpcpp = "@com_github_grpc_grpc//:grpc++" -gbenchmark = "@com_google_benchmark//:benchmark" loopback = "//tools/bazeldefs:loopback" -pkg_deb = _pkg_deb -pkg_tar = _pkg_tar -py_binary = native.py_binary rbe_platform = native.platform rbe_toolchain = native.toolchain -vdso_linker_option = "-fuse-ld=gold " def short_path(path): return path @@ -40,140 +18,6 @@ def proto_library(name, has_services = None, **kwargs): **kwargs ) -def cc_grpc_library(name, **kwargs): - _cc_grpc_library(name = name, grpc_only = True, **kwargs) - -def _go_proto_or_grpc_library(go_library_func, name, **kwargs): - deps = [ - dep.replace("_proto", "_go_proto") - for dep in (kwargs.pop("deps", []) or []) - ] - go_library_func( - name = name + "_go_proto", - importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto", - proto = ":" + name + "_proto", - deps = deps, - **kwargs - ) - -def go_proto_library(name, **kwargs): - _go_proto_or_grpc_library(_go_proto_library, name, **kwargs) - -def go_grpc_and_proto_libraries(name, **kwargs): - _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs) - -def cc_binary(name, static = False, **kwargs): - """Run cc_binary. - - Args: - name: name of the target. - static: make a static binary if True - **kwargs: the rest of the args. - """ - if static: - # How to statically link a c++ program that uses threads, like for gRPC: - # https://gcc.gnu.org/legacy-ml/gcc-help/2010-05/msg00029.html - if "linkopts" not in kwargs: - kwargs["linkopts"] = [] - kwargs["linkopts"] += [ - "-static", - "-lstdc++", - "-Wl,--whole-archive", - "-lpthread", - "-Wl,--no-whole-archive", - ] - _cc_binary( - name = name, - **kwargs - ) - -def go_binary(name, static = False, pure = False, x_defs = None, **kwargs): - """Build a go binary. - - Args: - name: name of the target. - static: build a static binary. - pure: build without cgo. - x_defs: additional definitions. - **kwargs: rest of the arguments are passed to _go_binary. - """ - if static: - kwargs["static"] = "on" - if pure: - kwargs["pure"] = "on" - _go_binary( - name = name, - x_defs = x_defs, - **kwargs - ) - -def go_importpath(target): - """Returns the importpath for the target.""" - return target[GoLibrary].importpath - -def go_library(name, **kwargs): - _go_library( - name = name, - importpath = "gvisor.dev/gvisor/" + native.package_name(), - **kwargs - ) - -def go_test(name, pure = False, library = None, **kwargs): - """Build a go test. - - Args: - name: name of the output binary. - pure: should it be built without cgo. - library: the library to embed. - **kwargs: rest of the arguments to pass to _go_test. - """ - if pure: - kwargs["pure"] = "on" - if library: - kwargs["embed"] = [library] - _go_test( - name = name, - **kwargs - ) - -def go_rule(rule, implementation, **kwargs): - """Wraps a rule definition with Go attributes. - - Args: - rule: rule function (typically rule or aspect). - implementation: implementation function. - **kwargs: other arguments to pass to rule. - - Returns: - The result of invoking the rule. - """ - attrs = kwargs.pop("attrs", dict()) - attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") - attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") - toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] - return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) - -def go_test_library(target): - if hasattr(target.attr, "embed") and len(target.attr.embed) > 0: - return target.attr.embed[0] - return None - -def go_context(ctx, std = False): - # We don't change anything for the standard library analysis. All Go files - # are available in all instances. Note that this includes the standard - # library sources, which are analyzed by nogo. - go_ctx = _go_context(ctx) - return struct( - go = go_ctx.go, - env = go_ctx.env, - nogo_args = [], - stdlib_srcs = go_ctx.sdk.srcs, - runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), - goos = go_ctx.sdk.goos, - goarch = go_ctx.sdk.goarch, - tags = go_ctx.tags, - ) - def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs): values = { "@bazel_tools//src/conditions:linux_x86_64": amd64, diff --git a/tools/bazeldefs/go.bzl b/tools/bazeldefs/go.bzl new file mode 100644 index 000000000..d388346a5 --- /dev/null +++ b/tools/bazeldefs/go.bzl @@ -0,0 +1,142 @@ +"""Go rules.""" + +load("@bazel_gazelle//:def.bzl", _gazelle = "gazelle") +load("@io_bazel_rules_go//go:def.bzl", "GoLibrary", _go_binary = "go_binary", _go_context = "go_context", _go_embed_data = "go_embed_data", _go_library = "go_library", _go_path = "go_path", _go_test = "go_test") +load("@io_bazel_rules_go//proto:def.bzl", _go_grpc_library = "go_grpc_library", _go_proto_library = "go_proto_library") +load("//tools/bazeldefs:defs.bzl", "select_arch", "select_system") + +gazelle = _gazelle +go_embed_data = _go_embed_data +go_path = _go_path + +def _go_proto_or_grpc_library(go_library_func, name, **kwargs): + deps = [ + dep.replace("_proto", "_go_proto") + for dep in (kwargs.pop("deps", []) or []) + ] + go_library_func( + name = name + "_go_proto", + importpath = "gvisor.dev/gvisor/" + native.package_name() + "/" + name + "_go_proto", + proto = ":" + name + "_proto", + deps = deps, + **kwargs + ) + +def go_proto_library(name, **kwargs): + _go_proto_or_grpc_library(_go_proto_library, name, **kwargs) + +def go_grpc_and_proto_libraries(name, **kwargs): + _go_proto_or_grpc_library(_go_grpc_library, name, **kwargs) + +def go_binary(name, static = False, pure = False, x_defs = None, **kwargs): + """Build a go binary. + + Args: + name: name of the target. + static: build a static binary. + pure: build without cgo. + x_defs: additional definitions. + **kwargs: rest of the arguments are passed to _go_binary. + """ + if static: + kwargs["static"] = "on" + if pure: + kwargs["pure"] = "on" + _go_binary( + name = name, + x_defs = x_defs, + **kwargs + ) + +def go_importpath(target): + """Returns the importpath for the target.""" + return target[GoLibrary].importpath + +def go_library(name, **kwargs): + _go_library( + name = name, + importpath = "gvisor.dev/gvisor/" + native.package_name(), + **kwargs + ) + +def go_test(name, pure = False, library = None, **kwargs): + """Build a go test. + + Args: + name: name of the output binary. + pure: should it be built without cgo. + library: the library to embed. + **kwargs: rest of the arguments to pass to _go_test. + """ + if pure: + kwargs["pure"] = "on" + if library: + kwargs["embed"] = [library] + _go_test( + name = name, + **kwargs + ) + +def go_rule(rule, implementation, **kwargs): + """Wraps a rule definition with Go attributes. + + Args: + rule: rule function (typically rule or aspect). + implementation: implementation function. + **kwargs: other arguments to pass to rule. + + Returns: + The result of invoking the rule. + """ + attrs = kwargs.pop("attrs", dict()) + attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") + attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") + toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] + return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) + +def go_test_library(target): + if hasattr(target.attr, "embed") and len(target.attr.embed) > 0: + return target.attr.embed[0] + return None + +def go_context(ctx, goos = None, goarch = None, std = False): + """Extracts a standard Go context struct. + + Args: + ctx: the starlark context (required). + goos: the GOOS value. + goarch: the GOARCH value. + std: ignored. + + Returns: + A context Go struct with pointers to Go toolchain components. + """ + + # We don't change anything for the standard library analysis. All Go files + # are available in all instances. Note that this includes the standard + # library sources, which are analyzed by nogo. + go_ctx = _go_context(ctx) + if goos == None: + goos = go_ctx.sdk.goos + elif goos != go_ctx.sdk.goos: + fail("Internal GOOS (%s) doesn't match GoSdk GOOS (%s)." % (goos, go_ctx.sdk.goos)) + if goarch == None: + goarch = go_ctx.sdk.goarch + elif goarch != go_ctx.sdk.goarch: + fail("Internal GOARCH (%s) doesn't match GoSdk GOARCH (%s)." % (goarch, go_ctx.sdk.goarch)) + return struct( + go = go_ctx.go, + env = go_ctx.env, + nogo_args = [], + stdlib_srcs = go_ctx.sdk.srcs, + runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), + goos = go_ctx.sdk.goos, + goarch = go_ctx.sdk.goarch, + tags = go_ctx.tags, + ) + +def select_goarch(): + return select_arch(arm64 = "arm64", amd64 = "amd64") + +def select_goos(): + return select_system(linux = "linux") diff --git a/tools/bazeldefs/pkg.bzl b/tools/bazeldefs/pkg.bzl new file mode 100644 index 000000000..56317d93f --- /dev/null +++ b/tools/bazeldefs/pkg.bzl @@ -0,0 +1,6 @@ +"""Packaging rules.""" + +load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") + +pkg_deb = _pkg_deb +pkg_tar = _pkg_tar diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index 523a42692..e5a7e23c7 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -67,6 +67,7 @@ import ( "go/token" "go/types" "io" + "log" "os" "os/exec" "path/filepath" @@ -619,7 +620,10 @@ func findReasons(pass *analysis.Pass, fdecl *ast.FuncDecl) ([]EscapeReason, bool func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { calls, err := loadObjdump() if err != nil { - return nil, err + // Note that if this analysis fails, then we don't actually + // fail the analyzer itself. We simply report every possible + // escape. In most cases this will work just fine. + log.Printf("WARNING: unable to load objdump: %v", err) } allEscapes := make(map[string][]Escapes) mergedEscapes := make(map[string]Escapes) @@ -641,6 +645,11 @@ func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { } hasCall := func(inst poser) (string, bool) { p := linePosition(inst, nil) + if calls == nil { + // See above: we don't have access to the binary + // itself, so need to include every possible call. + return "(possible)", true + } s, ok := calls[p.Simplified()] if !ok { return "", false diff --git a/tools/defs.bzl b/tools/defs.bzl index e2c72a1f6..bb291c512 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -7,39 +7,49 @@ change for Google-internal and bazel-compatible rules. load("//tools/go_stateify:defs.bzl", "go_stateify") load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps") -load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _gazelle = "gazelle", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _grpcpp = "grpcpp", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path", _vdso_linker_option = "vdso_linker_option") +load("//tools/nogo:defs.bzl", "nogo_test") +load("//tools/bazeldefs:defs.bzl", _build_test = "build_test", _bzl_library = "bzl_library", _coreutil = "coreutil", _default_installer = "default_installer", _default_net_util = "default_net_util", _loopback = "loopback", _proto_library = "proto_library", _rbe_platform = "rbe_platform", _rbe_toolchain = "rbe_toolchain", _select_arch = "select_arch", _select_system = "select_system", _short_path = "short_path") +load("//tools/bazeldefs:cc.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_grpc_library = "cc_grpc_library", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _gbenchmark = "gbenchmark", _grpcpp = "grpcpp", _gtest = "gtest", _vdso_linker_option = "vdso_linker_option") +load("//tools/bazeldefs:go.bzl", _gazelle = "gazelle", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_grpc_and_proto_libraries = "go_grpc_and_proto_libraries", _go_library = "go_library", _go_path = "go_path", _go_proto_library = "go_proto_library", _go_test = "go_test", _select_goarch = "select_goarch", _select_goos = "select_goos") +load("//tools/bazeldefs:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar") load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms") load("//tools/bazeldefs:tags.bzl", "go_suffixes") -load("//tools/nogo:defs.bzl", "nogo_test") -# Delegate directly. +# Core rules. build_test = _build_test bzl_library = _bzl_library +default_installer = _default_installer +default_net_util = _default_net_util +loopback = _loopback +select_arch = _select_arch +select_system = _select_system +short_path = _short_path +rbe_platform = _rbe_platform +rbe_toolchain = _rbe_toolchain +coreutil = _coreutil + +# C++ rules. cc_binary = _cc_binary cc_flags_supplier = _cc_flags_supplier cc_grpc_library = _cc_grpc_library cc_library = _cc_library cc_test = _cc_test cc_toolchain = _cc_toolchain -default_installer = _default_installer -default_net_util = _default_net_util gbenchmark = _gbenchmark +gtest = _gtest +grpcpp = _grpcpp +vdso_linker_option = _vdso_linker_option + +# Go rules. gazelle = _gazelle go_embed_data = _go_embed_data go_path = _go_path -gtest = _gtest -grpcpp = _grpcpp -loopback = _loopback +select_goos = _select_goos +select_goarch = _select_goarch + +# Packaging rules. pkg_deb = _pkg_deb pkg_tar = _pkg_tar -py_binary = _py_binary -select_arch = _select_arch -select_system = _select_system -short_path = _short_path -rbe_platform = _rbe_platform -rbe_toolchain = _rbe_toolchain -vdso_linker_option = _vdso_linker_option -coreutil = _coreutil # Platform options. default_platform = _default_platform @@ -66,9 +76,13 @@ def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, ** if nogo: # Note that the nogo rule applies only for go_library and go_test # targets, therefore we construct a library from the binary sources. + # This is done because the binary may not be in a form that objdump + # supports (i.e. a pure Go binary). _go_library( name = name + "_nogo_library", - **kwargs + srcs = kwargs.get("srcs", []), + deps = kwargs.get("deps", []), + testonly = 1, ) nogo_test( name = name + "_nogo", diff --git a/tools/github/main.go b/tools/github/main.go index 7a74dc033..681003eef 100644 --- a/tools/github/main.go +++ b/tools/github/main.go @@ -20,6 +20,7 @@ import ( "flag" "fmt" "io/ioutil" + "log" "os" "os/exec" "strings" @@ -34,21 +35,43 @@ var ( owner string repo string tokenFile string - path string + paths stringList commit string dryRun bool ) +type stringList []string + +func (s *stringList) String() string { + return strings.Join(*s, ",") +} + +func (s *stringList) Set(value string) error { + *s = append(*s, value) + return nil +} + // Keep the options simple for now. Supports only a single path and repo. func init() { flag.StringVar(&owner, "owner", "", "GitHub project org/owner (required, except nogo dry-run)") flag.StringVar(&repo, "repo", "", "GitHub repo (required, except nogo dry-run)") flag.StringVar(&tokenFile, "oauth-token-file", "", "file containing the GitHub token (or GITHUB_TOKEN is set)") - flag.StringVar(&path, "path", ".", "path to scan (required for revive and nogo)") + flag.Var(&paths, "path", "path(s) to scan (required for revive and nogo)") flag.StringVar(&commit, "commit", "", "commit to associated (required for nogo, except dry-run)") flag.BoolVar(&dryRun, "dry-run", false, "just print changes to be made") } +func filterPaths(paths []string) (existing []string) { + for _, path := range paths { + if _, err := os.Stat(path); err != nil { + log.Printf("WARNING: skipping %v: %v", path, err) + continue + } + existing = append(existing, path) + } + return +} + func main() { // Set defaults from the environment. repository := os.Getenv("GITHUB_REPOSITORY") @@ -83,8 +106,9 @@ func main() { flag.Usage() os.Exit(1) } - if len(path) == 0 { - fmt.Fprintln(flag.CommandLine.Output(), "missing --path option.") + filteredPaths := filterPaths(paths) + if len(filteredPaths) == 0 { + fmt.Fprintln(flag.CommandLine.Output(), "no valid --path options provided.") flag.Usage() os.Exit(1) } @@ -123,7 +147,7 @@ func main() { os.Exit(1) } // Scan the provided path. - rev := reviver.New([]string{path}, []reviver.Bugger{bugger}) + rev := reviver.New(filteredPaths, []reviver.Bugger{bugger}) if errs := rev.Run(); len(errs) > 0 { fmt.Fprintf(os.Stderr, "Encountered %d errors:\n", len(errs)) for _, err := range errs { @@ -145,7 +169,7 @@ func main() { } // Scan all findings. poster := nogo.NewFindingsPoster(client, owner, repo, commit, dryRun) - if err := poster.Walk(path); err != nil { + if err := poster.Walk(filteredPaths); err != nil { fmt.Fprintln(os.Stderr, "Error finding nogo findings:", err) os.Exit(1) } diff --git a/tools/github/nogo/nogo.go b/tools/github/nogo/nogo.go index b70dfe63b..b2bc63459 100644 --- a/tools/github/nogo/nogo.go +++ b/tools/github/nogo/nogo.go @@ -53,26 +53,31 @@ func NewFindingsPoster(client *github.Client, owner, repo, commit string, dryRun } // Walk walks the given path tree for findings files. -func (p *FindingsPoster) Walk(path string) error { - return filepath.Walk(path, func(filename string, info os.FileInfo, err error) error { - if err != nil { - return err - } - // Skip any directories or files not ending in .findings. - if !strings.HasSuffix(filename, ".findings") || info.IsDir() { +func (p *FindingsPoster) Walk(paths []string) error { + for _, path := range paths { + if err := filepath.Walk(path, func(filename string, info os.FileInfo, err error) error { + if err != nil { + return err + } + // Skip any directories or files not ending in .findings. + if !strings.HasSuffix(filename, ".findings") || info.IsDir() { + return nil + } + findings, err := util.ExtractFindingsFromFile(filename) + if err != nil { + return err + } + // Add all findings to the list. We use a map to ensure + // that each finding is unique. + for _, finding := range findings { + p.findings[finding] = struct{}{} + } return nil - } - findings, err := util.ExtractFindingsFromFile(filename) - if err != nil { + }); err != nil { return err } - // Add all findings to the list. We use a map to ensure - // that each finding is unique. - for _, finding := range findings { - p.findings[finding] = struct{}{} - } - return nil - }) + } + return nil } // Post posts all results to the GitHub API as a check run. diff --git a/tools/go_generics/defs.bzl b/tools/go_generics/defs.bzl index 33329cf28..ad97208a8 100644 --- a/tools/go_generics/defs.bzl +++ b/tools/go_generics/defs.bzl @@ -1,25 +1,32 @@ -"""Generics support via go_generics.""" +"""Generics support via go_generics. + +A Go template is similar to a go library, except that it has certain types that +can be replaced before usage. For example, one could define a templatized List +struct, whose elements are of type T, then instantiate that template for +T=segment, where "segment" is the concrete type. +""" TemplateInfo = provider( + "Information about a go_generics template.", fields = { + "unsafe": "whether the template requires unsafe code", "types": "required types", "opt_types": "optional types", "consts": "required consts", "opt_consts": "optional consts", "deps": "package dependencies", - "file": "merged template", + "template": "merged template source file", }, ) def _go_template_impl(ctx): srcs = ctx.files.srcs - output = ctx.outputs.out - - args = ["-o=%s" % output.path] + [f.path for f in srcs] + template = ctx.actions.declare_file(ctx.label.name + "_template.go") + args = ["-o=%s" % template.path] + [f.path for f in srcs] ctx.actions.run( inputs = srcs, - outputs = [output], + outputs = [template], mnemonic = "GoGenericsTemplate", progress_message = "Building Go template %s" % ctx.label, arguments = args, @@ -32,74 +39,48 @@ def _go_template_impl(ctx): consts = ctx.attr.consts, opt_consts = ctx.attr.opt_consts, deps = ctx.attr.deps, - file = output, + template = template, )] -""" -Generates a Go template from a set of Go files. - -A Go template is similar to a go library, except that it has certain types that -can be replaced before usage. For example, one could define a templatized List -struct, whose elements are of type T, then instantiate that template for -T=segment, where "segment" is the concrete type. - -Args: - name: the name of the template. - srcs: the list of source files that comprise the template. - types: the list of generic types in the template that are required to be specified. - opt_types: the list of generic types in the template that can but aren't required to be specified. - consts: the list of constants in the template that are required to be specified. - opt_consts: the list of constants in the template that can but aren't required to be specified. - deps: the list of dependencies. -""" go_template = rule( implementation = _go_template_impl, attrs = { - "srcs": attr.label_list(mandatory = True, allow_files = True), - "deps": attr.label_list(allow_files = True, cfg = "target"), - "types": attr.string_list(), - "opt_types": attr.string_list(), - "consts": attr.string_list(), - "opt_consts": attr.string_list(), + "srcs": attr.label_list(doc = "the list of source files that comprise the template", mandatory = True, allow_files = True), + "deps": attr.label_list(doc = "the standard dependency list", allow_files = True, cfg = "target"), + "types": attr.string_list(doc = "the list of generic types in the template that are required to be specified"), + "opt_types": attr.string_list(doc = "the list of generic types in the template that can but aren't required to be specified"), + "consts": attr.string_list(doc = "the list of constants in the template that are required to be specified"), + "opt_consts": attr.string_list(doc = "the list of constants in the template that can but aren't required to be specified"), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics/go_merge")), }, - outputs = { - "out": "%{name}_template.go", - }, -) - -TemplateInstanceInfo = provider( - fields = { - "srcs": "source files", - }, ) def _go_template_instance_impl(ctx): - template = ctx.attr.template[TemplateInfo] + info = ctx.attr.template[TemplateInfo] output = ctx.outputs.out # Check that all required types are defined. - for t in template.types: + for t in info.types: if t not in ctx.attr.types: fail("Missing value for type %s in %s" % (t, ctx.attr.template.label)) # Check that all defined types are expected by the template. for t in ctx.attr.types: - if (t not in template.types) and (t not in template.opt_types): + if (t not in info.types) and (t not in info.opt_types): fail("Type %s it not a parameter to %s" % (t, ctx.attr.template.label)) # Check that all required consts are defined. - for t in template.consts: + for t in info.consts: if t not in ctx.attr.consts: fail("Missing value for constant %s in %s" % (t, ctx.attr.template.label)) # Check that all defined consts are expected by the template. for t in ctx.attr.consts: - if (t not in template.consts) and (t not in template.opt_consts): + if (t not in info.consts) and (t not in info.opt_consts): fail("Const %s it not a parameter to %s" % (t, ctx.attr.template.label)) # Build the argument list. - args = ["-i=%s" % template.file.path, "-o=%s" % output.path] + args = ["-i=%s" % info.template.path, "-o=%s" % output.path] if ctx.attr.package: args.append("-p=%s" % ctx.attr.package) @@ -117,7 +98,7 @@ def _go_template_instance_impl(ctx): args.append("-anon") ctx.actions.run( - inputs = [template.file], + inputs = [info.template], outputs = [output], mnemonic = "GoGenericsInstance", progress_message = "Building Go template instance %s" % ctx.label, @@ -125,35 +106,22 @@ def _go_template_instance_impl(ctx): executable = ctx.executable._tool, ) - return [TemplateInstanceInfo( - srcs = [output], + return [DefaultInfo( + files = depset([output]), )] -""" -Instantiates a Go template by replacing all generic types with concrete ones. - -Args: - name: the name of the template instance. - template: the label of the template to be instatiated. - prefix: a prefix to be added to globals in the template. - suffix: a suffix to be added to global in the template. - types: the map from generic type names to concrete ones. - consts: the map from constant names to their values. - imports: the map from imports used in types/consts to their import paths. - package: the name of the package the instantiated template will be compiled into. -""" go_template_instance = rule( implementation = _go_template_instance_impl, attrs = { - "template": attr.label(mandatory = True), - "prefix": attr.string(), - "suffix": attr.string(), - "types": attr.string_dict(), - "consts": attr.string_dict(), - "imports": attr.string_dict(), - "anon": attr.bool(mandatory = False, default = False), - "package": attr.string(mandatory = False), - "out": attr.output(mandatory = True), + "template": attr.label(doc = "the label of the template to be instantiated", mandatory = True), + "prefix": attr.string(doc = "a prefix to be added to globals in the template"), + "suffix": attr.string(doc = "a suffix to be added to globals in the template"), + "types": attr.string_dict(doc = "the map from generic type names to concrete ones"), + "consts": attr.string_dict(doc = "the map from constant names to their values"), + "imports": attr.string_dict(doc = "the map from imports used in types/consts to their import paths"), + "anon": attr.bool(doc = "whether anoymous fields should be processed", mandatory = False, default = False), + "package": attr.string(doc = "the package for the generated source file", mandatory = False), + "out": attr.output(doc = "output file", mandatory = True), "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_generics")), }, ) diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index dd4b46f58..3c6be3339 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -1,8 +1,15 @@ -load("//tools:defs.bzl", "bzl_library", "go_library") -load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib") +load("//tools:defs.bzl", "bzl_library", "go_library", "select_goarch", "select_goos") +load("//tools/nogo:defs.bzl", "nogo_objdump_tool", "nogo_stdlib", "nogo_target") package(licenses = ["notice"]) +nogo_target( + name = "target", + goarch = select_goarch(), + goos = select_goos(), + visibility = ["//visibility:public"], +) + nogo_objdump_tool( name = "objdump_tool", visibility = ["//visibility:public"], diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 8079618ab..0853f03cf 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -473,6 +473,7 @@ func init() { "pkg/shim/v2/options/options.go:24", "pkg/shim/v2/options/options.go:26", "pkg/shim/v2/runtimeoptions/runtimeoptions.go:16", + "pkg/shim/v2/runtimeoptions/runtimeoptions_cri.go", // Generated: exempt all. "pkg/shim/v2/runtimeoptions/runtimeoptions_test.go:22", "pkg/shim/v2/service.go:15", "pkg/shim/v2/service_linux.go:18", diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index c6fcfd402..543598b52 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -1,10 +1,34 @@ """Nogo rules.""" -load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") +load("//tools/bazeldefs:go.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") -def _nogo_objdump_tool_impl(ctx): - go_ctx = go_context(ctx) +NogoTargetInfo = provider( + "information about the Go target", + fields = { + "goarch": "the build architecture (GOARCH)", + "goos": "the build OS target (GOOS)", + }, +) + +def _nogo_target_impl(ctx): + return [NogoTargetInfo( + goarch = ctx.attr.goarch, + goos = ctx.attr.goos, + )] +nogo_target = go_rule( + rule, + implementation = _nogo_target_impl, + attrs = { + # goarch is the build architecture. This will normally be provided by a + # select statement, but this information is propagated to other rules. + "goarch": attr.string(mandatory = True), + # goos is similarly the build operating system target. + "goos": attr.string(mandatory = True), + }, +) + +def _nogo_objdump_tool_impl(ctx): # Construct the magic dump command. # # Note that in some cases, the input is being fed into the tool via stdin. @@ -12,6 +36,8 @@ def _nogo_objdump_tool_impl(ctx): # we need the tool to handle this case by creating a temporary file. # # [1] https://github.com/golang/go/issues/41051 + nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) dumper = ctx.actions.declare_file(ctx.label.name) ctx.actions.write(dumper, "\n".join([ @@ -42,6 +68,12 @@ def _nogo_objdump_tool_impl(ctx): nogo_objdump_tool = go_rule( rule, implementation = _nogo_objdump_tool_impl, + attrs = { + "_nogo_target": attr.label( + default = "//tools/nogo:target", + cfg = "target", + ), + }, ) # NogoStdlibInfo is the set of standard library facts. @@ -54,9 +86,9 @@ NogoStdlibInfo = provider( ) def _nogo_stdlib_impl(ctx): - go_ctx = go_context(ctx) - # Build the standard library facts. + nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(ctx.label.name + ".facts") findings = ctx.actions.declare_file(ctx.label.name + ".findings") config = struct( @@ -70,12 +102,12 @@ def _nogo_stdlib_impl(ctx): ctx.actions.run( inputs = [config_file] + go_ctx.stdlib_srcs, outputs = [facts, findings], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), - executable = ctx.files._nogo[0], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), + executable = ctx.files._nogo_check[0], mnemonic = "GoStandardLibraryAnalysis", progress_message = "Analyzing Go Standard Library", arguments = go_ctx.nogo_args + [ - "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, + "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-stdlib=%s" % config_file.path, "-findings=%s" % findings.path, "-facts=%s" % facts.path, @@ -92,11 +124,17 @@ nogo_stdlib = go_rule( rule, implementation = _nogo_stdlib_impl, attrs = { - "_nogo": attr.label( + "_nogo_check": attr.label( default = "//tools/nogo/check:check", + cfg = "host", ), - "_objdump_tool": attr.label( + "_nogo_objdump_tool": attr.label( default = "//tools/nogo:objdump_tool", + cfg = "host", + ), + "_nogo_target": attr.label( + default = "//tools/nogo:target", + cfg = "target", ), }, ) @@ -113,20 +151,18 @@ NogoInfo = provider( "findings": "package findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", - "srcs": "original source files (for go_test support)", - "deps": "original deps (for go_test support)", + "srcs": "srcs (for go_test support)", + "deps": "deps (for go_test support)", }, ) def _nogo_aspect_impl(target, ctx): - go_ctx = go_context(ctx) - # If this is a nogo rule itself (and not the shadow of a go_library or # go_binary rule created by such a rule), then we simply return nothing. # All work is done in the shadow properties for go rules. For a proto # library, we simply skip the analysis portion but still need to return a # valid NogoInfo to reference the generated binary. - if ctx.rule.kind in ("go_library", "go_binary", "go_test", "go_tool_library"): + if ctx.rule.kind in ("go_library", "go_tool_library", "go_binary", "go_test"): srcs = ctx.rule.files.srcs deps = ctx.rule.attr.deps elif ctx.rule.kind in ("go_proto_library", "go_wrap_cc"): @@ -200,10 +236,13 @@ def _nogo_aspect_impl(target, ctx): inputs += info.binaries # Add the standard library facts. - stdlib_facts = ctx.attr._nogo_stdlib[NogoStdlibInfo].facts + stdlib_info = ctx.attr._nogo_stdlib[NogoStdlibInfo] + stdlib_facts = stdlib_info.facts inputs.append(stdlib_facts) # The nogo tool operates on a configuration serialized in JSON format. + nogo_target_info = ctx.attr._nogo_target[NogoTargetInfo] + go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(target.label.name + ".facts") findings = ctx.actions.declare_file(target.label.name + ".findings") escapes = ctx.actions.declare_file(target.label.name + ".escapes") @@ -224,13 +263,13 @@ def _nogo_aspect_impl(target, ctx): ctx.actions.run( inputs = inputs, outputs = [facts, findings, escapes], - tools = depset(go_ctx.runfiles.to_list() + ctx.files._objdump_tool), - executable = ctx.files._nogo[0], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), + executable = ctx.files._nogo_check[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, arguments = go_ctx.nogo_args + [ "-binary=%s" % target_objfile.path, - "-objdump_tool=%s" % ctx.files._objdump_tool[0].path, + "-objdump_tool=%s" % ctx.files._nogo_objdump_tool[0].path, "-package=%s" % config_file.path, "-findings=%s" % findings.path, "-facts=%s" % facts.path, @@ -266,9 +305,22 @@ nogo_aspect = go_rule( "embed", ], attrs = { - "_nogo": attr.label(default = "//tools/nogo/check:check"), - "_nogo_stdlib": attr.label(default = "//tools/nogo:stdlib"), - "_objdump_tool": attr.label(default = "//tools/nogo:objdump_tool"), + "_nogo_check": attr.label( + default = "//tools/nogo/check:check", + cfg = "host", + ), + "_nogo_stdlib": attr.label( + default = "//tools/nogo:stdlib", + cfg = "host", + ), + "_nogo_objdump_tool": attr.label( + default = "//tools/nogo:objdump_tool", + cfg = "host", + ), + "_nogo_target": attr.label( + default = "//tools/nogo:target", + cfg = "target", + ), }, ) diff --git a/tools/nogo/gentest.sh b/tools/nogo/gentest.sh index 033da11ad..0a762f9f6 100755 --- a/tools/nogo/gentest.sh +++ b/tools/nogo/gentest.sh @@ -34,6 +34,7 @@ for filename in "$@"; do continue fi while read -r line; do + line="${line@Q}" violations=$((${violations}+1)); echo "echo -e '\\033[0;31m${line}\\033[0;31m\\033[0m'" >> "${output}" done < "${filename}" diff --git a/tools/nogo/io_bazel_rules_go-visibility.patch b/tools/nogo/io_bazel_rules_go-visibility.patch deleted file mode 100644 index 6b64b2e85..000000000 --- a/tools/nogo/io_bazel_rules_go-visibility.patch +++ /dev/null @@ -1,25 +0,0 @@ -diff --git a/third_party/org_golang_x_tools-extras.patch b/third_party/org_golang_x_tools-extras.patch -index 133fbccc..5f0d9a47 100644 ---- a/third_party/org_golang_x_tools-extras.patch -+++ b/third_party/org_golang_x_tools-extras.patch -@@ -32,7 +32,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ - - go_library( - name = "go_default_library", --@@ -14,6 +14,23 @@ -+@@ -14,6 +14,20 @@ - ], - ) - -@@ -43,10 +43,7 @@ diff -urN c/go/analysis/internal/facts/BUILD.bazel d/go/analysis/internal/facts/ - + "imports.go", - + ], - + importpath = "golang.org/x/tools/go/analysis/internal/facts", --+ visibility = [ --+ "//go/analysis:__subpackages__", --+ "@io_bazel_rules_go//go/tools/builders:__pkg__", --+ ], -++ visibility = ["//visibility:public"], - + deps = [ - + "//go/analysis:go_tool_library", - + "//go/types/objectpath:go_tool_library", diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index 120fdcff5..e19e3c237 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -264,12 +264,17 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str // Closure to check a single package. allFindings := make([]string, 0) stdlibFacts := make(map[string][]byte) + stdlibErrs := make(map[string]error) var checkOne func(pkg string) error // Recursive. checkOne = func(pkg string) error { // Is this already done? if _, ok := stdlibFacts[pkg]; ok { return nil } + // Did this fail previously? + if _, ok := stdlibErrs[pkg]; ok { + return nil + } // Lookup the configuration. config, ok := packages[pkg] @@ -283,6 +288,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str // If there's no binary for this package, it is likely // not built with the distribution. That's fine, we can // just skip analysis. + stdlibErrs[pkg] = err return nil } @@ -299,6 +305,7 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str if err != nil { // If we can't analyze a package from the standard library, // then we skip it. It will simply not have any findings. + stdlibErrs[pkg] = err return nil } stdlibFacts[pkg] = factData @@ -312,7 +319,9 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str // to evaluate in the order provided here. We do ensure however, that // all packages are evaluated. for pkg := range packages { - checkOne(pkg) + if err := checkOne(pkg); err != nil { + return nil, nil, err + } } // Sanity check. @@ -326,6 +335,11 @@ func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]str return nil, nil, fmt.Errorf("error saving stdlib facts: %w", err) } + // Write out all errors. + for pkg, err := range stdlibErrs { + log.Printf("WARNING: error while processing %v: %v", pkg, err) + } + // Return all findings. return allFindings, factData, nil } @@ -522,15 +536,15 @@ func Main() { findings, factData, err = checkPackage(c, analyzerConfig, nil) // Do we need to do escape analysis? if *escapesOutput != "" { - escapes, _, err := checkPackage(c, escapesConfig, nil) - if err != nil { - log.Fatalf("error performing escape analysis: %v", err) - } f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { log.Fatalf("unable to open output %q: %v", *escapesOutput, err) } defer f.Close() + escapes, _, err := checkPackage(c, escapesConfig, nil) + if err != nil { + log.Fatalf("error performing escape analysis: %v", err) + } for _, escape := range escapes { fmt.Fprintf(f, "%s\n", escape) } diff --git a/tools/rules_go.patch b/tools/rules_go.patch new file mode 100644 index 000000000..5e1e87084 --- /dev/null +++ b/tools/rules_go.patch @@ -0,0 +1,14 @@ +diff --git a/go/private/rules/test.bzl b/go/private/rules/test.bzl +index 17516ad7..76b6c68c 100644 +--- a/go/private/rules/test.bzl ++++ b/go/private/rules/test.bzl +@@ -121,9 +121,6 @@ def _go_test_impl(ctx): + ) + + test_gc_linkopts = gc_linkopts(ctx) +- if not go.mode.debug: +- # Disable symbol table and DWARF generation for test binaries. +- test_gc_linkopts.extend(["-s", "-w"]) + + # Now compile the test binary itself + test_library = GoLibrary( diff --git a/website/BUILD b/website/BUILD index 6d92d9103..f3642b903 100644 --- a/website/BUILD +++ b/website/BUILD @@ -1,17 +1,15 @@ load("//tools:defs.bzl", "bzl_library", "pkg_tar") load("//website:defs.bzl", "doc", "docs") +load("//images:defs.bzl", "docker_image") package(licenses = ["notice"]) -# website is the full container image. Note that this actually just collects -# other dependendcies and runs Docker locally to import and tag the image. -sh_binary( +docker_image( name = "website", - srcs = ["import.sh"], data = [":files"], - tags = [ - "local", - "manual", + statements = [ + "EXPOSE 8080/tcp", + 'ENTRYPOINT ["/server"]', ], ) diff --git a/website/import.sh b/website/import.sh deleted file mode 100755 index e1350e83d..000000000 --- a/website/import.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash - -# Copyright 2018 The gVisor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -xeuo pipefail - -if [[ -d $0.runfiles ]]; then - cd $0.runfiles -fi - -exec docker import \ - -c "EXPOSE 8080/tcp" \ - -c "ENTRYPOINT [\"/server\"]" \ - $(find . -name files.tgz) \ - gvisor.dev/images/website |